This is an automated email from the ASF dual-hosted git repository.

karan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 4322ff8849e Push down datasource to realtime tasks (#18235)
4322ff8849e is described below

commit 4322ff8849e92accd58b952b5fbdcd5d58456830
Author: Adarsh Sanjeev <[email protected]>
AuthorDate: Sat Jul 19 21:30:55 2025 +0530

    Push down datasource to realtime tasks (#18235)
    
    * Add test
    
    * Fix setup
    
    * Tests
    
    * Add test
    
    * Write unnest base
    
    * Fix
    
    * Dart tests
    
    * Remove duplicate modules
    
    * Add unnest test
    
    * Address review comments
    
    * Fix failed test
    
    * Address review comments
    
    * Update error message
    
    * Refactor
    
    * Refactor
    
    * Add case for restricted datasource and tests
    
    * Fix broken mocks
    
    * Fix yet more broken mocks
---
 .../embedded/msq/BaseRealtimeQueryTest.java        | 196 +++++++
 .../msq/EmbeddedDurableShuffleStorageTest.java     |  12 +-
 .../testing/embedded/msq/EmbeddedMSQApis.java      |  17 +-
 .../embedded/msq/EmbeddedMSQRealtimeQueryTest.java | 567 +++++++++++++++------
 .../msq/EmbeddedMSQRealtimeUnnestQueryTest.java    | 174 +++++++
 .../dart/worker/DartDataServerQueryHandler.java    |  13 +-
 .../worker/DartDataServerQueryHandlerFactory.java  |   6 +-
 .../msq/exec/DataServerQueryHandlerFactory.java    |   3 +-
 .../msq/exec/DataServerQueryHandlerUtils.java      |  68 ++-
 .../indexing/IndexerDataServerQueryHandler.java    |  17 +-
 .../IndexerDataServerQueryHandlerFactory.java      |   6 +-
 .../msq/input/table/SegmentsInputSliceReader.java  |   3 +
 .../msq/exec/DataServerQueryHandlerUtilsTest.java  |  53 ++
 .../IndexerDataServerQueryHandlerTest.java         |   4 +-
 .../druid/msq/test/CalciteMSQTestsHelper.java      |  28 +-
 .../org/apache/druid/msq/test/MSQTestBase.java     |   9 +-
 .../testing/embedded/EmbeddedClusterApis.java      |  12 +-
 17 files changed, 975 insertions(+), 213 deletions(-)

diff --git 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/BaseRealtimeQueryTest.java
 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/BaseRealtimeQueryTest.java
new file mode 100644
index 00000000000..3a05637f660
--- /dev/null
+++ 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/BaseRealtimeQueryTest.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.testing.embedded.msq;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import it.unimi.dsi.fastutil.bytes.ByteArrays;
+import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.data.input.impl.JsonInputFormat;
+import org.apache.druid.data.input.impl.TimestampSpec;
+import org.apache.druid.frame.testutil.FrameTestUtil;
+import org.apache.druid.indexer.TaskStatusPlus;
+import org.apache.druid.indexer.granularity.UniformGranularitySpec;
+import org.apache.druid.indexing.kafka.KafkaIndexTaskModule;
+import org.apache.druid.indexing.kafka.simulate.KafkaResource;
+import org.apache.druid.indexing.kafka.supervisor.KafkaSupervisorIOConfig;
+import org.apache.druid.indexing.kafka.supervisor.KafkaSupervisorSpec;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.parsers.CloseableIterator;
+import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.QueryableIndexCursorFactory;
+import org.apache.druid.segment.TestHelper;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.indexing.DataSchema;
+import org.apache.druid.testing.embedded.EmbeddedClusterApis;
+import org.apache.druid.testing.embedded.EmbeddedDruidCluster;
+import org.apache.druid.testing.embedded.junit5.EmbeddedClusterTestBase;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.joda.time.Period;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+
+/**
+ * Base test for Kafka-related embedded tests.
+ */
+public class BaseRealtimeQueryTest extends EmbeddedClusterTestBase
+{
+  private static final Period TASK_DURATION = Period.hours(1);
+  private static final int TASK_COUNT = 2;
+
+  protected KafkaResource kafka;
+  protected String topic;
+
+  @Override
+  protected EmbeddedDruidCluster createCluster()
+  {
+    kafka = new KafkaResource();
+    return EmbeddedDruidCluster
+        .withEmbeddedDerbyAndZookeeper()
+        .addExtension(KafkaIndexTaskModule.class)
+        .addResource(kafka);
+  }
+
+  @BeforeEach
+  void setupCreateKafkaTopic()
+  {
+    // Create Kafka topic.
+    topic = EmbeddedClusterApis.createTestDatasourceName();
+    kafka.createTopicWithPartitions(topic, 2);
+  }
+
+  /**
+   * Submits a supervisor spec to the Overlord.
+   */
+  protected void submitSupervisor()
+  {
+    // Submit a supervisor.
+    final KafkaSupervisorSpec kafkaSupervisorSpec = createKafkaSupervisor();
+    final Map<String, String> startSupervisorResult =
+        cluster.callApi().onLeaderOverlord(o -> 
o.postSupervisor(kafkaSupervisorSpec));
+    Assertions.assertEquals(Map.of("id", dataSource), startSupervisorResult);
+  }
+
+  /**
+   * Publishes data from a {@link QueryableIndex} to Kafka.
+   */
+  protected void publishToKafka(QueryableIndex index)
+  {
+    // Send data to Kafka.
+    final QueryableIndexCursorFactory wikiCursorFactory =
+        new QueryableIndexCursorFactory(index);
+    final RowSignature wikiSignature = wikiCursorFactory.getRowSignature();
+    kafka.produceRecordsToTopic(
+        FrameTestUtil.readRowsFromCursorFactory(wikiCursorFactory)
+                     .map(row -> {
+                       final Map<String, Object> rowMap = new 
LinkedHashMap<>();
+                       for (int i = 0; i < row.size(); i++) {
+                         rowMap.put(wikiSignature.getColumnName(i), 
row.get(i));
+                       }
+                       try {
+                         return new ProducerRecord<>(
+                             topic,
+                             ByteArrays.EMPTY_ARRAY,
+                             TestHelper.JSON_MAPPER.writeValueAsBytes(rowMap)
+                         );
+                       }
+                       catch (JsonProcessingException e) {
+                         throw new RuntimeException(e);
+                       }
+                     })
+                     .toList()
+    );
+  }
+
+  @AfterEach
+  void tearDownEach() throws ExecutionException, InterruptedException, 
IOException
+  {
+    final Map<String, String> terminateSupervisorResult =
+        cluster.callApi().onLeaderOverlord(o -> 
o.terminateSupervisor(dataSource));
+    Assertions.assertEquals(Map.of("id", dataSource), 
terminateSupervisorResult);
+
+    // Cancel all running tasks, so we don't need to wait for them to hand off 
their segments.
+    try (final CloseableIterator<TaskStatusPlus> it = 
cluster.leaderOverlord().taskStatuses(null, null, null).get()) {
+      while (it.hasNext()) {
+        cluster.leaderOverlord().cancelTask(it.next().getId());
+      }
+    }
+
+    kafka.deleteTopic(topic);
+  }
+
+  private KafkaSupervisorSpec createKafkaSupervisor()
+  {
+    final Period startDelay = Period.millis(10);
+    final Period supervisorRunPeriod = Period.millis(500);
+    final boolean useEarliestOffset = true;
+
+    return new KafkaSupervisorSpec(
+        dataSource,
+        null,
+        DataSchema.builder()
+                  .withDataSource(dataSource)
+                  .withTimestamp(new TimestampSpec("__time", "auto", null))
+                  .withGranularity(new 
UniformGranularitySpec(Granularities.DAY, null, null))
+                  
.withDimensions(DimensionsSpec.builder().useSchemaDiscovery(true).build())
+                  .build(),
+        null,
+        new KafkaSupervisorIOConfig(
+            topic,
+            null,
+            new JsonInputFormat(null, null, null, null, null),
+            null,
+            TASK_COUNT,
+            TASK_DURATION,
+            kafka.consumerProperties(),
+            null,
+            null,
+            null,
+            startDelay,
+            supervisorRunPeriod,
+            useEarliestOffset,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        ),
+        null,
+        null,
+        null,
+        null,
+        null,
+        null,
+        null,
+        null,
+        null,
+        null,
+        null
+    );
+  }
+}
diff --git 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedDurableShuffleStorageTest.java
 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedDurableShuffleStorageTest.java
index e3f07a1afe2..9a7aba6bf8d 100644
--- 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedDurableShuffleStorageTest.java
+++ 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedDurableShuffleStorageTest.java
@@ -153,7 +153,7 @@ public class EmbeddedDurableShuffleStorageTest extends 
EmbeddedClusterTestBase
         dataSource
     );
 
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
 
     BaseCalciteQueryTest.assertResultsEquals(
         sql,
@@ -175,7 +175,7 @@ public class EmbeddedDurableShuffleStorageTest extends 
EmbeddedClusterTestBase
         dataSource
     );
 
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
 
     BaseCalciteQueryTest.assertResultsEquals(
         sql,
@@ -203,7 +203,7 @@ public class EmbeddedDurableShuffleStorageTest extends 
EmbeddedClusterTestBase
         dataSource
     );
 
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
 
     BaseCalciteQueryTest.assertResultsEquals(
         sql,
@@ -237,7 +237,7 @@ public class EmbeddedDurableShuffleStorageTest extends 
EmbeddedClusterTestBase
         dataSource
     );
 
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
 
     BaseCalciteQueryTest.assertResultsEquals(
         sql,
@@ -274,7 +274,7 @@ public class EmbeddedDurableShuffleStorageTest extends 
EmbeddedClusterTestBase
         dataSource
     );
 
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
 
     BaseCalciteQueryTest.assertResultsEquals(
         sql,
@@ -359,7 +359,7 @@ public class EmbeddedDurableShuffleStorageTest extends 
EmbeddedClusterTestBase
         )
     );
 
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
     Assertions.assertEquals(TaskState.SUCCESS, 
payload.getStatus().getStatus());
     Assertions.assertEquals(1, 
payload.getStatus().getSegmentLoadWaiterStatus().getTotalSegments());
     Assertions.assertNull(payload.getStatus().getErrorReport());
diff --git 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQApis.java
 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQApis.java
index a53b42bb3ee..4c002039f4b 100644
--- 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQApis.java
+++ 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQApis.java
@@ -80,11 +80,11 @@ public class EmbeddedMSQApis
 
   /**
    * Submits the given SQL query to any of the brokers (using {@code 
BrokerClient})
-   * of the cluster. Waits for it to complete, then returns the query report.
+   * of the cluster, checks that the task has started and returns the {@link 
SqlTaskStatus}.
   *
   * @return the {@code SqlTaskStatus} of the submitted (started) task.
   */
-  public MSQTaskReportPayload runTaskSql(String sql, Object... args)
+  public SqlTaskStatus submitTaskSql(String sql, Object... args)
   {
     final SqlTaskStatus taskStatus =
         FutureUtils.getUnchecked(
@@ -110,6 +110,19 @@ public class EmbeddedMSQApis
       );
     }
 
+    return taskStatus;
+  }
+
+  /**
+   * Submits the given SQL query to any of the brokers (using {@code 
BrokerClient})
+   * of the cluster. Waits for it to complete, then returns the query report.
+   *
+   * @return the {@link MSQTaskReportPayload} produced by the completed query.
+   */
+  public MSQTaskReportPayload runTaskSqlAndGetReport(String sql, Object... 
args)
+  {
+    SqlTaskStatus taskStatus = submitTaskSql(sql, args);
+
     cluster.callApi().waitForTaskToSucceed(taskStatus.getTaskId(), overlord);
 
     final TaskReport.ReportMap taskReport = FutureUtils.getUnchecked(
diff --git 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQRealtimeQueryTest.java
 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQRealtimeQueryTest.java
index 8ac2d13c237..c2433b86bf3 100644
--- 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQRealtimeQueryTest.java
+++ 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQRealtimeQueryTest.java
@@ -19,21 +19,9 @@
 
 package org.apache.druid.testing.embedded.msq;
 
-import com.fasterxml.jackson.core.JsonProcessingException;
-import it.unimi.dsi.fastutil.bytes.ByteArrays;
-import org.apache.druid.data.input.impl.DimensionsSpec;
-import org.apache.druid.data.input.impl.JsonInputFormat;
-import org.apache.druid.data.input.impl.TimestampSpec;
-import org.apache.druid.frame.testutil.FrameTestUtil;
-import org.apache.druid.indexer.TaskStatusPlus;
-import org.apache.druid.indexer.granularity.UniformGranularitySpec;
-import org.apache.druid.indexing.kafka.KafkaIndexTaskModule;
-import org.apache.druid.indexing.kafka.simulate.KafkaResource;
-import org.apache.druid.indexing.kafka.supervisor.KafkaSupervisorIOConfig;
-import org.apache.druid.indexing.kafka.supervisor.KafkaSupervisorSpec;
+import org.apache.druid.client.indexing.TaskStatusResponse;
+import org.apache.druid.indexer.TaskState;
 import org.apache.druid.java.util.common.StringUtils;
-import org.apache.druid.java.util.common.granularity.Granularities;
-import org.apache.druid.java.util.common.parsers.CloseableIterator;
 import org.apache.druid.msq.dart.guice.DartControllerMemoryManagementModule;
 import org.apache.druid.msq.dart.guice.DartControllerModule;
 import org.apache.druid.msq.dart.guice.DartWorkerMemoryManagementModule;
@@ -45,11 +33,8 @@ import org.apache.druid.msq.guice.MSQSqlModule;
 import org.apache.druid.msq.guice.SqlTaskModule;
 import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
 import org.apache.druid.query.DruidMetrics;
-import org.apache.druid.segment.QueryableIndexCursorFactory;
-import org.apache.druid.segment.TestHelper;
+import org.apache.druid.query.http.SqlTaskStatus;
 import org.apache.druid.segment.TestIndex;
-import org.apache.druid.segment.column.RowSignature;
-import org.apache.druid.segment.indexing.DataSchema;
 import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
 import org.apache.druid.testing.embedded.EmbeddedBroker;
 import org.apache.druid.testing.embedded.EmbeddedClusterApis;
@@ -59,30 +44,44 @@ import org.apache.druid.testing.embedded.EmbeddedHistorical;
 import org.apache.druid.testing.embedded.EmbeddedIndexer;
 import org.apache.druid.testing.embedded.EmbeddedOverlord;
 import org.apache.druid.testing.embedded.EmbeddedRouter;
-import org.apache.druid.testing.embedded.junit5.EmbeddedClusterTestBase;
-import org.apache.kafka.clients.producer.ProducerRecord;
-import org.joda.time.Period;
-import org.junit.jupiter.api.AfterEach;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
+import org.junit.internal.matchers.ThrowableMessageMatcher;
 import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
 
-import java.io.IOException;
 import java.util.Collections;
-import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
-import java.util.concurrent.ExecutionException;
 
 /**
  * Embedded test to ingest {@link TestIndex#getMMappedWikipediaIndex()} into 
Kafka tasks, then query
  * those tasks with MSQ.
  */
-public class EmbeddedMSQRealtimeQueryTest extends EmbeddedClusterTestBase
+public class EmbeddedMSQRealtimeQueryTest extends BaseRealtimeQueryTest
 {
-  private static final Period TASK_DURATION = Period.hours(1);
-  private static final int TASK_COUNT = 2;
+  private static final String LOOKUP_TABLE = "test_lookup";
+  private static final String BULK_UPDATE_LOOKUP_PAYLOAD
+      = "{\n"
+        + "  \"__default\": {\n"
+        + "    \"%s\": {\n"
+        + "      \"version\": \"v1\",\n"
+        + "      \"lookupExtractorFactory\": {\n"
+        + "        \"type\": \"map\",\n"
+        + "        \"map\": {\n"
+        + "          \"#en.wikipedia\": \"English\",\n"
+        + "          \"#fr.wikipedia\": \"French\",\n"
+        + "          \"#eu.wikipedia\": \"European\",\n"
+        + "          \"#ar.wikipedia\": \"Arabic\",\n"
+        + "          \"#cs.wikipedia\": \"Czech\"\n"
+        + "        }\n"
+        + "      }\n"
+        + "    }\n"
+        + "  }\n"
+        + "}";
 
   private final EmbeddedBroker broker = new EmbeddedBroker();
   private final EmbeddedIndexer indexer = new EmbeddedIndexer();
@@ -92,14 +91,12 @@ public class EmbeddedMSQRealtimeQueryTest extends 
EmbeddedClusterTestBase
   private final EmbeddedRouter router = new EmbeddedRouter();
   private final int totalRows = 
TestIndex.getMMappedWikipediaIndex().getNumRows();
 
-  private KafkaResource kafka;
-  private String topic;
   private EmbeddedMSQApis msqApis;
 
   @Override
   public EmbeddedDruidCluster createCluster()
   {
-    kafka = new KafkaResource();
+    EmbeddedDruidCluster clusterWithKafka = super.createCluster();
 
     coordinator.addProperty("druid.manager.segments.useIncrementalCache", 
"always");
 
@@ -110,25 +107,23 @@ public class EmbeddedMSQRealtimeQueryTest extends 
EmbeddedClusterTestBase
           .addProperty("druid.query.default.context.maxConcurrentStages", "1");
 
     historical.addProperty("druid.msq.dart.worker.heapFraction", "0.9")
-              .addProperty("druid.msq.dart.worker.concurrentQueries", "1");
+              .addProperty("druid.msq.dart.worker.concurrentQueries", "1")
+              .addProperty("druid.lookup.enableLookupSyncOnStartup", "true");
 
-    indexer.setServerMemory(300_000_000) // to run 2x realtime and 2x MSQ tasks
+    indexer.setServerMemory(400_000_000) // to run 2x realtime and 2x MSQ tasks
            .addProperty("druid.segment.handoff.pollDuration", "PT0.1s")
            // druid.processing.numThreads must be higher than # of MSQ tasks 
to avoid contention, because the realtime
            // server is contacted in such a way that the processing thread is 
blocked
            .addProperty("druid.processing.numThreads", "3")
-           .addProperty("druid.worker.capacity", "4");
+           .addProperty("druid.worker.capacity", "4")
+           .addProperty("druid.lookup.enableLookupSyncOnStartup", "true");
 
-    return EmbeddedDruidCluster
-        .withEmbeddedDerbyAndZookeeper()
+    return clusterWithKafka
         .addExtensions(
-            KafkaIndexTaskModule.class,
             DartControllerModule.class,
             DartWorkerModule.class,
             DartControllerMemoryManagementModule.class,
-            DartControllerModule.class,
             DartWorkerMemoryManagementModule.class,
-            DartWorkerModule.class,
             IndexerMemoryManagementModule.class,
             MSQDurableStorageModule.class,
             MSQIndexingModule.class,
@@ -138,54 +133,42 @@ public class EmbeddedMSQRealtimeQueryTest extends 
EmbeddedClusterTestBase
         .addCommonProperty("druid.monitoring.emissionPeriod", "PT0.1s")
         .addCommonProperty("druid.msq.dart.enabled", "true")
         .useLatchableEmitter()
-        .addResource(kafka)
         .addServer(coordinator)
         .addServer(overlord)
-        .addServer(indexer)
-        .addServer(broker)
-        .addServer(historical)
         .addServer(router);
   }
 
+  @BeforeAll
+  protected void setupLookups() throws Exception
+  {
+    // Initialize lookups
+    cluster.callApi().onLeaderCoordinator(
+        c -> c.updateAllLookups(Map.of())
+    );
+
+    final String lookupPayload = StringUtils.format(
+        BULK_UPDATE_LOOKUP_PAYLOAD,
+        LOOKUP_TABLE
+    );
+    cluster.callApi().onLeaderCoordinator(
+        c -> 
c.updateAllLookups(EmbeddedClusterApis.deserializeJsonToMap(lookupPayload))
+    );
+
+    // Initialize the broker/data-servers later so that lookups are loaded on 
startup.
+    cluster.addServer(broker)
+           .addServer(indexer)
+           .addServer(historical);
+    broker.start();
+    indexer.start();
+    historical.start();
+  }
+
   @BeforeEach
   void setUpEach()
   {
     msqApis = new EmbeddedMSQApis(cluster, overlord);
-    topic = dataSource = EmbeddedClusterApis.createTestDatasourceName();
-
-    // Create Kafka topic.
-    kafka.createTopicWithPartitions(topic, 2);
-
-    // Submit a supervisor.
-    final KafkaSupervisorSpec kafkaSupervisorSpec = createKafkaSupervisor();
-    final Map<String, String> startSupervisorResult =
-        cluster.callApi().onLeaderOverlord(o -> 
o.postSupervisor(kafkaSupervisorSpec));
-    Assertions.assertEquals(Map.of("id", dataSource), startSupervisorResult);
-
-    // Send data to Kafka.
-    final QueryableIndexCursorFactory wikiCursorFactory =
-        new QueryableIndexCursorFactory(TestIndex.getMMappedWikipediaIndex());
-    final RowSignature wikiSignature = wikiCursorFactory.getRowSignature();
-    kafka.produceRecordsToTopic(
-        FrameTestUtil.readRowsFromCursorFactory(wikiCursorFactory)
-                     .map(row -> {
-                       final Map<String, Object> rowMap = new 
LinkedHashMap<>();
-                       for (int i = 0; i < row.size(); i++) {
-                         rowMap.put(wikiSignature.getColumnName(i), 
row.get(i));
-                       }
-                       try {
-                         return new ProducerRecord<>(
-                             topic,
-                             ByteArrays.EMPTY_ARRAY,
-                             TestHelper.JSON_MAPPER.writeValueAsBytes(rowMap)
-                         );
-                       }
-                       catch (JsonProcessingException e) {
-                         throw new RuntimeException(e);
-                       }
-                     })
-                     .toList()
-    );
+    submitSupervisor();
+    publishToKafka(TestIndex.getMMappedWikipediaIndex());
 
     // Wait for it to be loaded.
     indexer.latchableEmitter().waitForEventAggregate(
@@ -195,29 +178,12 @@ public class EmbeddedMSQRealtimeQueryTest extends 
EmbeddedClusterTestBase
     );
   }
 
-  @AfterEach
-  void tearDownEach() throws ExecutionException, InterruptedException, 
IOException
-  {
-    final Map<String, String> terminateSupervisorResult =
-        cluster.callApi().onLeaderOverlord(o -> 
o.terminateSupervisor(dataSource));
-    Assertions.assertEquals(Map.of("id", dataSource), 
terminateSupervisorResult);
-
-    // Cancel all running tasks, so we don't need to wait for them to hand off 
their segments.
-    try (final CloseableIterator<TaskStatusPlus> it = 
cluster.leaderOverlord().taskStatuses(null, null, null).get()) {
-      while (it.hasNext()) {
-        cluster.leaderOverlord().cancelTask(it.next().getId());
-      }
-    }
-
-    kafka.deleteTopic(topic);
-  }
-
   @Test
   @Timeout(60)
   public void test_selectCount_task_default()
   {
     final String sql = StringUtils.format("SELECT COUNT(*) FROM \"%s\"", 
dataSource);
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
 
     // By default tasks do not include realtime data; count is zero.
     BaseCalciteQueryTest.assertResultsEquals(
@@ -237,7 +203,7 @@ public class EmbeddedMSQRealtimeQueryTest extends 
EmbeddedClusterTestBase
         dataSource
     );
 
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
 
     BaseCalciteQueryTest.assertResultsEquals(
         sql,
@@ -273,12 +239,111 @@ public class EmbeddedMSQRealtimeQueryTest extends 
EmbeddedClusterTestBase
 
   @Test
   @Timeout(60)
-  @Disabled // Test does not currently pass, see 
https://github.com/apache/druid/issues/18198
-  public void test_selectJoin_dart()
+  public void test_selectBroadcastJoin_dart()
+  {
+    final String sql = "SELECT COUNT(*) FROM \"%s\"\n"
+                       + "WHERE countryName IN (\n"
+                       + "  SELECT countryName\n"
+                       + "  FROM \"%s\"\n"
+                       + "  WHERE countryName IS NOT NULL\n"
+                       + "  GROUP BY 1\n"
+                       + "  ORDER BY COUNT(*) DESC\n"
+                       + "  LIMIT 1\n"
+                       + ")";
+
+    MatcherAssert.assertThat(
+        Assertions.assertThrows(
+            RuntimeException.class,
+            () -> msqApis.runDartSql(sql, dataSource, dataSource)
+        ),
+        ThrowableMessageMatcher.hasMessage(
+            CoreMatchers.containsString(
+                "Cannot handle stage with multiple sources while querying 
realtime data. If using broadcast "
+                + "joins, try setting[sqlJoinAlgorithm] to[sortMerge] in your 
query context."
+            )
+        )
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_selectBroadcastJoin_task_withRealtime()
+  {
+    final String sql = StringUtils.format(
+        "SET includeSegmentSource = 'REALTIME';\n"
+        + "SELECT COUNT(*) FROM \"%s\"\n"
+        + "WHERE countryName IN (\n"
+        + "  SELECT countryName\n"
+        + "  FROM \"%s\"\n"
+        + "  WHERE countryName IS NOT NULL\n"
+        + "  GROUP BY 1\n"
+        + "  ORDER BY COUNT(*) DESC\n"
+        + "  LIMIT 1\n"
+        + ")",
+        dataSource,
+        dataSource
+    );
+
+    SqlTaskStatus taskStatus = msqApis.submitTaskSql(sql);
+
+    String taskId = taskStatus.getTaskId();
+    cluster.callApi().waitForTaskToFinish(taskId, overlord);
+
+    final TaskStatusResponse currentStatus = 
cluster.callApi().onLeaderOverlord(
+        o -> o.taskStatus(taskId)
+    );
+    Assertions.assertNotNull(currentStatus.getStatus());
+    Assertions.assertEquals(
+        TaskState.FAILED,
+        currentStatus.getStatus().getStatusCode(),
+        StringUtils.format("Task[%s] has unexpected status", taskId)
+    );
+
+    Assertions.assertTrue(
+        CoreMatchers.containsString(
+            "Cannot handle stage with multiple sources while querying realtime 
data. If using broadcast "
+            + "joins, try setting[sqlJoinAlgorithm] to[sortMerge] in your 
query context."
+        ).matches(currentStatus.getStatus().getErrorMsg())
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_selectSortMergeJoin_task_withRealtime()
+  {
+    final String sql = StringUtils.format(
+        "SET includeSegmentSource = 'REALTIME';"
+        + "SET sqlJoinAlgorithm = 'sortMerge';\n"
+        + "SELECT COUNT(*) FROM \"%s\"\n"
+        + "WHERE countryName IN (\n"
+        + "  SELECT countryName\n"
+        + "  FROM \"%s\"\n"
+        + "  WHERE countryName IS NOT NULL\n"
+        + "  GROUP BY 1\n"
+        + "  ORDER BY COUNT(*) DESC\n"
+        + "  LIMIT 1\n"
+        + ")",
+        dataSource,
+        dataSource
+    );
+
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
+
+    BaseCalciteQueryTest.assertResultsEquals(
+        sql,
+        Collections.singletonList(new Object[]{528}),
+        payload.getResults().getResults()
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_selectSortMergeJoin_dart()
   {
     final long selectedCount = Long.parseLong(
         msqApis.runDartSql(
-            "SELECT COUNT(*) FROM \"%s\"\n"
+            "SET sqlJoinAlgorithm = 'sortMerge';\n"
+            + "SELECT COUNT(*) FROM \"%s\"\n"
             + "WHERE countryName IN (\n"
             + "  SELECT countryName\n"
             + "  FROM \"%s\"\n"
@@ -297,83 +362,265 @@ public class EmbeddedMSQRealtimeQueryTest extends 
EmbeddedClusterTestBase
 
   @Test
   @Timeout(60)
-  @Disabled // Test does not currently pass, see 
https://github.com/apache/druid/issues/18198
-  public void test_selectJoin_task_withRealtime()
+  public void test_selectJoinwithLookup_task_withRealtime()
   {
     final String sql = StringUtils.format(
         "SET includeSegmentSource = 'REALTIME';\n"
-        + "SELECT COUNT(*) FROM \"%s\"\n"
-        + "WHERE countryName IN (\n"
-        + "  SELECT countryName\n"
-        + "  FROM \"%s\"\n"
-        + "  WHERE countryName IS NOT NULL\n"
-        + "  GROUP BY 1\n"
-        + "  ORDER BY COUNT(*) DESC\n"
-        + "  LIMIT 1\n"
-        + ")",
+        + "SELECT \n"
+        + " l.v AS newName, \n"
+        + " SUM(w.\"added\") AS total\n"
+        + "FROM \"%s\" w INNER JOIN lookup.%s l ON w.\"channel\" = l.k\n"
+        + "GROUP BY 1\n"
+        + "ORDER BY 2 DESC\n",
+        dataSource,
+        LOOKUP_TABLE
+    );
+
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
+
+    BaseCalciteQueryTest.assertResultsEquals(
+        sql,
+        List.of(
+            new Object[]{"English", 3045299},
+            new Object[]{"French", 642555},
+            new Object[]{"Arabic", 153605},
+            new Object[]{"Czech", 132768},
+            new Object[]{"European", 6690}
+        ),
+        payload.getResults().getResults()
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_selectJoinwithLookup_dart()
+  {
+    final String sql = StringUtils.format(
+        "SELECT \n"
+        + " l.v AS newName, \n"
+        + " SUM(w.\"added\") AS total\n"
+        + "FROM \"%s\" w INNER JOIN lookup.%s l ON w.\"channel\" = l.k\n"
+        + "GROUP BY 1\n"
+        + "ORDER BY 2 DESC\n",
         dataSource,
+        LOOKUP_TABLE
+    );
+
+    final String result = msqApis.runDartSql(sql);
+
+    Assertions.assertEquals(
+        "English,3045299\n"
+        + "French,642555\n"
+        + "Arabic,153605\n"
+        + "Czech,132768\n"
+        + "European,6690",
+        result
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_selectJoinWithConcatVirtualDimension_task_withRealtime()
+  {
+    final String sql = StringUtils.format(
+        "SET includeSegmentSource = 'REALTIME';\n"
+        + "SELECT\n"
+        + "  \"channel\",\n"
+        + "  \"countryIsoCode\",\n"
+        + "  CONCAT(w.\"cityName\", ': ', l.v),\n"
+        + "  \"user\"\n"
+        + "FROM %s w\n"
+        + "  INNER JOIN lookup.%s l ON w.\"channel\" = l.k\n"
+        + "WHERE\n"
+        + "  w.\"cityName\" IS NOT NULL\n"
+        + "  AND \"added\" > 1000 AND \"delta\" > 5000\n"
+        + "ORDER BY 3 DESC\n",
+        dataSource,
+        LOOKUP_TABLE
+    );
+
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
+
+    BaseCalciteQueryTest.assertResultsEquals(
+        sql,
+        List.of(
+            new Object[]{"#en.wikipedia", "GB", "London: English", 
"78.145.31.93"},
+            new Object[]{"#ar.wikipedia", "AE", "Dubai: Arabic", "86.98.5.51"},
+            new Object[]{"#en.wikipedia", "IN", "Bhopal: English", 
"14.139.241.50"}
+        ),
+        payload.getResults().getResults()
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_selectJoinWithConcatVirtualDimension_dart()
+  {
+    final String sql = StringUtils.format(
+        "SELECT\n"
+        + "  \"channel\",\n"
+        + "  \"countryIsoCode\",\n"
+        + "  CONCAT(w.\"cityName\", ': ', l.v),\n"
+        + "  \"user\"\n"
+        + "FROM %s w\n"
+        + "  INNER JOIN lookup.%s l ON w.\"channel\" = l.k\n"
+        + "WHERE\n"
+        + "  w.\"cityName\" IS NOT NULL\n"
+        + "  AND \"added\" > 1000 AND \"delta\" > 5000\n"
+        + "ORDER BY 3 DESC\n",
+        dataSource,
+        LOOKUP_TABLE
+    );
+
+    final String results = msqApis.runDartSql(sql);
+
+    Assertions.assertEquals(
+        "#en.wikipedia,GB,London: English,78.145.31.93\n"
+        + "#ar.wikipedia,AE,Dubai: Arabic,86.98.5.51\n"
+        + "#en.wikipedia,IN,Bhopal: English,14.139.241.50",
+        results
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_scanWithFilter_task_withRealtime()
+  {
+    final String sql = StringUtils.format(
+        "SET includeSegmentSource = 'REALTIME';\n"
+        + "SELECT \"channel\", \"page\", \"user\", \"deleted\"\n"
+        + "FROM \"%s\"\n"
+        + "WHERE \"cityName\" = 'Sydney' AND \"delta\" > 10",
         dataSource
     );
 
-    final MSQTaskReportPayload payload = msqApis.runTaskSql(sql);
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
 
     BaseCalciteQueryTest.assertResultsEquals(
         sql,
-        Collections.singletonList(new Object[]{528}),
+        List.of(
+            new Object[]{"#en.wikipedia", "Coca-Cola formula", 
"124.169.17.234", 0},
+            new Object[]{"#en.wikipedia", "List of Harry Potter characters", 
"121.211.82.121", 0}
+        ),
         payload.getResults().getResults()
     );
   }
 
-  private KafkaSupervisorSpec createKafkaSupervisor()
+
+  @Test
+  @Timeout(60)
+  public void test_scanWithFilter_dart()
   {
-    final Period startDelay = Period.millis(10);
-    final Period supervisorRunPeriod = Period.millis(500);
-    final boolean useEarliestOffset = true;
+    final String sql = StringUtils.format(
+        "SELECT \"channel\", \"page\", \"user\", \"deleted\"\n"
+        + "FROM \"%s\"\n"
+        + "WHERE \"cityName\" = 'Sydney' AND \"delta\" > 10",
+        dataSource
+    );
+
+    final String result = msqApis.runDartSql(sql);
+
+    Assertions.assertEquals(
+        "#en.wikipedia,Coca-Cola formula,124.169.17.234,0\n#en.wikipedia,List 
of Harry Potter characters,121.211.82.121,0",
+        result
+    );
+  }
 
-    return new KafkaSupervisorSpec(
+  @Test
+  @Timeout(60)
+  public void test_groupByWithFilter_task_withRealtime()
+  {
+    final String sql = StringUtils.format(
+        "SET includeSegmentSource = 'REALTIME';\n"
+        + "SELECT \"channel\", COUNT(*)\n"
+        + "FROM \"%s\"\n"
+        + "WHERE \"countryName\" = 'Australia'\n"
+        + "GROUP BY 1\n"
+        + "ORDER BY 1 DESC",
+        dataSource
+    );
+
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
+
+    BaseCalciteQueryTest.assertResultsEquals(
+        sql,
+        List.of(
+            new Object[]{"#en.wikipedia", 63},
+            new Object[]{"#de.wikipedia", 2}
+        ),
+        payload.getResults().getResults()
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_groupByWithFilter_dart()
+  {
+    final String sql = StringUtils.format(
+        "SELECT \"channel\", COUNT(*)\n"
+        + "FROM \"%s\"\n"
+        + "WHERE \"countryName\" = 'Australia'\n"
+        + "GROUP BY 1\n"
+        + "ORDER BY 1 DESC",
+        dataSource
+    );
+
+    final String result = msqApis.runDartSql(sql);
+
+    Assertions.assertEquals("#en.wikipedia,63\n#de.wikipedia,2", result);
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_scanWithFilterAfterJoin_task_withRealtime()
+  {
+    final String sql = StringUtils.format(
+        "SET includeSegmentSource = 'REALTIME';\n"
+        + "SELECT \n"
+        + "  \"page\", \n"
+        + "  \"user\", \n"
+        + "  \"added\"\n"
+        + "FROM %s w\n"
+        + "  INNER JOIN lookup.%s l ON w.\"channel\" = l.k\n"
+        + "WHERE CONCAT(w.\"cityName\", ': ', l.v) = 'London: English' AND 
\"comment\" IN ('/* Works */', '/* Early life */')\n",
         dataSource,
-        null,
-        DataSchema.builder()
-                  .withDataSource(dataSource)
-                  .withTimestamp(new TimestampSpec("__time", "auto", null))
-                  .withGranularity(new 
UniformGranularitySpec(Granularities.DAY, null, null))
-                  
.withDimensions(DimensionsSpec.builder().useSchemaDiscovery(true).build())
-                  .build(),
-        null,
-        new KafkaSupervisorIOConfig(
-            topic,
-            null,
-            new JsonInputFormat(null, null, null, null, null),
-            null,
-            TASK_COUNT,
-            TASK_DURATION,
-            kafka.consumerProperties(),
-            null,
-            null,
-            null,
-            startDelay,
-            supervisorRunPeriod,
-            useEarliestOffset,
-            null,
-            null,
-            null,
-            null,
-            null,
-            null,
-            null,
-            null
+        LOOKUP_TABLE
+    );
+
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
+
+    BaseCalciteQueryTest.assertResultsEquals(
+        sql,
+        List.of(
+            new Object[]{"Andy Wilman", "109.156.217.121", 0},
+            new Object[]{"Angharad Rees", "89.240.46.182", 578},
+            new Object[]{"Chazz Palminteri", "81.178.229.60", 10}
         ),
-        null,
-        null,
-        null,
-        null,
-        null,
-        null,
-        null,
-        null,
-        null,
-        null,
-        null
+        payload.getResults().getResults()
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_scanWithFilterAfterJoin_dart()
+  {
+    final String sql = StringUtils.format(
+        "SELECT \n"
+        + "  \"page\", \n"
+        + "  \"user\", \n"
+        + "  \"added\"\n"
+        + "FROM %s w\n"
+        + "  INNER JOIN lookup.%s l ON w.\"channel\" = l.k\n"
+        + "WHERE CONCAT(w.\"cityName\", ': ', l.v) = 'London: English' AND 
\"comment\" IN ('/* Works */', '/* Early life */')\n",
+        dataSource,
+        LOOKUP_TABLE
+    );
+
+    final String result = msqApis.runDartSql(sql);
+
+    Assertions.assertEquals(
+        "Andy Wilman,109.156.217.121,0\nAngharad Rees,89.240.46.182,578\nChazz 
Palminteri,81.178.229.60,10",
+        result
     );
   }
 }
diff --git 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQRealtimeUnnestQueryTest.java
 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQRealtimeUnnestQueryTest.java
new file mode 100644
index 00000000000..cb38ce54acc
--- /dev/null
+++ 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/msq/EmbeddedMSQRealtimeUnnestQueryTest.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.testing.embedded.msq;
+
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.msq.dart.guice.DartControllerMemoryManagementModule;
+import org.apache.druid.msq.dart.guice.DartControllerModule;
+import org.apache.druid.msq.dart.guice.DartWorkerMemoryManagementModule;
+import org.apache.druid.msq.dart.guice.DartWorkerModule;
+import org.apache.druid.msq.guice.IndexerMemoryManagementModule;
+import org.apache.druid.msq.guice.MSQDurableStorageModule;
+import org.apache.druid.msq.guice.MSQIndexingModule;
+import org.apache.druid.msq.guice.MSQSqlModule;
+import org.apache.druid.msq.guice.SqlTaskModule;
+import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
+import org.apache.druid.query.DruidMetrics;
+import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.TestIndex;
+import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
+import org.apache.druid.testing.embedded.EmbeddedBroker;
+import org.apache.druid.testing.embedded.EmbeddedCoordinator;
+import org.apache.druid.testing.embedded.EmbeddedDruidCluster;
+import org.apache.druid.testing.embedded.EmbeddedHistorical;
+import org.apache.druid.testing.embedded.EmbeddedIndexer;
+import org.apache.druid.testing.embedded.EmbeddedOverlord;
+import org.apache.druid.testing.embedded.EmbeddedRouter;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.util.Collections;
+import java.util.List;
+
+public class EmbeddedMSQRealtimeUnnestQueryTest extends BaseRealtimeQueryTest
+{
+  private final EmbeddedBroker broker = new EmbeddedBroker();
+  private final EmbeddedIndexer indexer = new EmbeddedIndexer();
+  private final EmbeddedOverlord overlord = new EmbeddedOverlord();
+  private final EmbeddedHistorical historical = new EmbeddedHistorical();
+  private final EmbeddedCoordinator coordinator = new EmbeddedCoordinator();
+  private final EmbeddedRouter router = new EmbeddedRouter();
+
+  private EmbeddedMSQApis msqApis;
+
+  @Override
+  public EmbeddedDruidCluster createCluster()
+  {
+    EmbeddedDruidCluster clusterWithKafka = super.createCluster();
+
+    coordinator.addProperty("druid.manager.segments.useIncrementalCache", 
"always");
+
+    overlord.addProperty("druid.manager.segments.useIncrementalCache", 
"always")
+            .addProperty("druid.manager.segments.pollDuration", "PT0.1s");
+
+    broker.addProperty("druid.msq.dart.controller.heapFraction", "0.9")
+          .addProperty("druid.query.default.context.maxConcurrentStages", "1");
+
+    historical.addProperty("druid.msq.dart.worker.heapFraction", "0.9")
+              .addProperty("druid.msq.dart.worker.concurrentQueries", "1");
+
+    indexer.setServerMemory(300_000_000) // to run 2x realtime and 2x MSQ tasks
+           .addProperty("druid.segment.handoff.pollDuration", "PT0.1s")
+           // druid.processing.numThreads must be higher than # of MSQ tasks 
to avoid contention, because the realtime
+           // server is contacted in such a way that the processing thread is 
blocked
+           .addProperty("druid.processing.numThreads", "3")
+           .addProperty("druid.worker.capacity", "4");
+
+    return clusterWithKafka
+        .addExtensions(
+            DartControllerModule.class,
+            DartWorkerModule.class,
+            DartControllerMemoryManagementModule.class,
+            DartWorkerMemoryManagementModule.class,
+            IndexerMemoryManagementModule.class,
+            MSQDurableStorageModule.class,
+            MSQIndexingModule.class,
+            MSQSqlModule.class,
+            SqlTaskModule.class
+        )
+        .addCommonProperty("druid.monitoring.emissionPeriod", "PT0.1s")
+        .addCommonProperty("druid.msq.dart.enabled", "true")
+        .useLatchableEmitter()
+        .addServer(coordinator)
+        .addServer(overlord)
+        .addServer(router)
+        .addServer(broker)
+        .addServer(historical)
+        .addServer(indexer);
+  }
+
+  @BeforeEach
+  void setUpEach()
+  {
+    msqApis = new EmbeddedMSQApis(cluster, overlord);
+
+    QueryableIndex index = TestIndex.getMMappedTestIndex();
+
+    submitSupervisor();
+    publishToKafka(index);
+
+    final int totalRows = index.getNumRows();
+
+    // Wait for it to be loaded.
+    indexer.latchableEmitter().waitForEventAggregate(
+        event -> event.hasMetricName("ingest/events/processed")
+                      .hasDimension(DruidMetrics.DATASOURCE, 
Collections.singletonList(dataSource)),
+        agg -> agg.hasSumAtLeast(totalRows)
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_unnest_task_withRealtime()
+  {
+    final String sql = StringUtils.format(
+        "SET includeSegmentSource = 'REALTIME';\n"
+        + "SELECT d3 FROM \"%s\" CROSS JOIN 
UNNEST(MV_TO_ARRAY(\"placementish\")) AS d3\n"
+        + "LIMIT 5",
+        dataSource
+    );
+    final MSQTaskReportPayload payload = msqApis.runTaskSqlAndGetReport(sql);
+
+    BaseCalciteQueryTest.assertResultsEquals(
+        sql,
+        List.of(
+            new Object[]{"a"},
+            new Object[]{"preferred"},
+            new Object[]{"b"},
+            new Object[]{"preferred"},
+            new Object[]{"e"}
+        ),
+        payload.getResults().getResults()
+    );
+  }
+
+  @Test
+  @Timeout(60)
+  public void test_unnest_dart()
+  {
+    final String sql = StringUtils.format(
+        "SELECT d3 FROM \"%s\" CROSS JOIN 
UNNEST(MV_TO_ARRAY(\"placementish\")) AS d3\n"
+        + "LIMIT 5",
+        dataSource
+    );
+    final String result = msqApis.runDartSql(sql);
+
+    Assertions.assertEquals(
+        "a\n"
+        + "preferred\n"
+        + "b\n"
+        + "preferred\n"
+        + "e",
+        result
+    );
+  }
+}
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataServerQueryHandler.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataServerQueryHandler.java
index 4a6b67da669..244a462bf83 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataServerQueryHandler.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataServerQueryHandler.java
@@ -56,7 +56,8 @@ import java.util.stream.Collectors;
  */
 public class DartDataServerQueryHandler implements DataServerQueryHandler
 {
-  private final String dataSource;
+  private final int inputNumber;
+  private final String dataSourceName;
   private final ChannelCounters channelCounters;
   private final ServiceClientFactory serviceClientFactory;
   private final ObjectMapper objectMapper;
@@ -64,7 +65,8 @@ public class DartDataServerQueryHandler implements 
DataServerQueryHandler
   private final DataServerRequestDescriptor requestDescriptor;
 
   public DartDataServerQueryHandler(
-      String dataSource,
+      int inputNumber,
+      String dataSourceName,
       ChannelCounters channelCounters,
       ServiceClientFactory serviceClientFactory,
       ObjectMapper objectMapper,
@@ -72,7 +74,8 @@ public class DartDataServerQueryHandler implements 
DataServerQueryHandler
       DataServerRequestDescriptor requestDescriptor
   )
   {
-    this.dataSource = dataSource;
+    this.inputNumber = inputNumber;
+    this.dataSourceName = dataSourceName;
     this.channelCounters = channelCounters;
     this.serviceClientFactory = serviceClientFactory;
     this.objectMapper = objectMapper;
@@ -97,7 +100,7 @@ public class DartDataServerQueryHandler implements 
DataServerQueryHandler
   {
     final Query<QueryType> preparedQuery =
         Queries.withSpecificSegments(
-            DataServerQueryHandlerUtils.prepareQuery(query, dataSource),
+            DataServerQueryHandlerUtils.prepareQuery(query, inputNumber, 
dataSourceName),
             requestDescriptor.getSegments()
                              .stream()
                              .map(RichSegmentDescriptor::toPlainDescriptor)
@@ -140,7 +143,7 @@ public class DartDataServerQueryHandler implements 
DataServerQueryHandler
           return new DataServerQueryResult<>(
               Collections.singletonList(yielder),
               Collections.emptyList(),
-              dataSource
+              dataSourceName
           );
         }
     );
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataServerQueryHandlerFactory.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataServerQueryHandlerFactory.java
index a7fc2004f7b..56334071a77 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataServerQueryHandlerFactory.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataServerQueryHandlerFactory.java
@@ -48,13 +48,15 @@ public class DartDataServerQueryHandlerFactory implements 
DataServerQueryHandler
 
   @Override
   public DartDataServerQueryHandler createDataServerQueryHandler(
-      String dataSource,
+      int inputNumber,
+      String dataSourceName,
       ChannelCounters channelCounters,
       DataServerRequestDescriptor requestDescriptor
   )
   {
     return new DartDataServerQueryHandler(
-        dataSource,
+        inputNumber,
+        dataSourceName,
         channelCounters,
         serviceClientFactory,
         objectMapper,
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/DataServerQueryHandlerFactory.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/DataServerQueryHandlerFactory.java
index 245c078a1c8..27a7d097c03 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/DataServerQueryHandlerFactory.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/DataServerQueryHandlerFactory.java
@@ -28,7 +28,8 @@ import 
org.apache.druid.msq.input.table.DataServerRequestDescriptor;
 public interface DataServerQueryHandlerFactory
 {
   DataServerQueryHandler createDataServerQueryHandler(
-      String dataSource,
+      int inputNumber,
+      String dataSourceName,
       ChannelCounters channelCounters,
       DataServerRequestDescriptor dataServerRequestDescriptor
   );
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/DataServerQueryHandlerUtils.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/DataServerQueryHandlerUtils.java
index ce63f8100e1..392e6ee378e 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/DataServerQueryHandlerUtils.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/DataServerQueryHandlerUtils.java
@@ -20,19 +20,27 @@
 package org.apache.druid.msq.exec;
 
 import org.apache.druid.discovery.DataServerClient;
+import org.apache.druid.error.DruidException;
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.guava.Yielder;
 import org.apache.druid.java.util.common.guava.Yielders;
 import org.apache.druid.msq.counters.ChannelCounters;
+import org.apache.druid.msq.querykit.InputNumberDataSource;
+import org.apache.druid.msq.querykit.RestrictedInputNumberDataSource;
+import org.apache.druid.query.DataSource;
+import org.apache.druid.query.JoinAlgorithm;
 import org.apache.druid.query.Queries;
 import org.apache.druid.query.Query;
+import org.apache.druid.query.RestrictedDataSource;
 import org.apache.druid.query.SegmentDescriptor;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.context.ResponseContext;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 
 import java.util.Collections;
 import java.util.List;
 import java.util.function.Function;
+import java.util.stream.Collectors;
 
 /**
  * Static utility functions for {@link DataServerQueryHandler} implementations.
@@ -48,17 +56,65 @@ public class DataServerQueryHandlerUtils
    * Performs necessary transforms to a query destined for data servers. Does 
not update the list of segments; callers
    * should do this themselves using {@link 
Queries#withSpecificSegments(Query, List)}.
    *
-   * @param query      the query
-   * @param dataSource datasource name
+   * @param query          the query
+   * @param dataSourceName datasource name
    */
-  public static <R, T extends Query<R>> Query<R> prepareQuery(final T query, 
final String dataSource)
+  public static <R, T extends Query<R>> Query<R> prepareQuery(
+      final T query,
+      final int inputNumber,
+      final String dataSourceName
+  )
   {
     // MSQ changes the datasource to an inputNumber datasource. This needs to 
be changed back for data servers
     // to understand.
+    return query.withDataSource(transformDatasource(query.getDataSource(), 
inputNumber, dataSourceName));
+  }
 
-    // BUG: This transformation is incorrect; see 
https://github.com/apache/druid/issues/18198. It loses decorations
-    // such as join, unnest, etc.
-    return query.withDataSource(new TableDataSource(dataSource));
+  /**
+   * Transforms {@link InputNumberDataSource} and {@link 
RestrictedInputNumberDataSource}, which are only understood
+   * by MSQ tasks, back into {@link TableDataSource} and {@link 
RestrictedDataSource} recursively.
+   */
+  static DataSource transformDatasource(
+      final DataSource dataSource,
+      final int inputNumber,
+      final String dataSourceName
+  )
+  {
+    if (dataSource instanceof InputNumberDataSource) {
+      InputNumberDataSource numberDataSource = (InputNumberDataSource) 
dataSource;
+      if (numberDataSource.getInputNumber() == inputNumber) {
+        return new TableDataSource(dataSourceName);
+      } else {
+        throw DruidException.forPersona(DruidException.Persona.USER)
+                            .ofCategory(DruidException.Category.UNSUPPORTED)
+                            .build(
+                                "Cannot handle stage with multiple sources 
while querying realtime data. "
+                                + "If using broadcast joins, try setting[%s] 
to[%s] in your query context.",
+                                PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+                                JoinAlgorithm.SORT_MERGE.toString()
+                            );
+      }
+    } else if (dataSource instanceof RestrictedInputNumberDataSource) {
+      RestrictedInputNumberDataSource restrictedDatasource = 
(RestrictedInputNumberDataSource) dataSource;
+      if (restrictedDatasource.getInputNumber() == inputNumber) {
+        return RestrictedDataSource.create(new 
TableDataSource(dataSourceName), restrictedDatasource.getPolicy());
+      } else {
+        throw DruidException.forPersona(DruidException.Persona.USER)
+                            .ofCategory(DruidException.Category.UNSUPPORTED)
+                            .build(
+                                "Cannot handle stage with multiple sources 
while querying realtime data. "
+                                + "If using broadcast joins, try setting[%s] 
to[%s] in your query context.",
+                                PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+                                JoinAlgorithm.SORT_MERGE.toString()
+                            );
+      }
+    } else {
+      List<DataSource> transformed = dataSource.getChildren()
+                                               .stream()
+                                               .map(ds -> 
transformDatasource(ds, inputNumber, dataSourceName))
+                                               .collect(Collectors.toList());
+      return dataSource.withChildren(transformed);
+    }
   }
 
   /**
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandler.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandler.java
index b8b2afd5137..8d4b1c20c30 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandler.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandler.java
@@ -78,7 +78,8 @@ public class IndexerDataServerQueryHandler implements 
DataServerQueryHandler
   private static final Logger log = new 
Logger(IndexerDataServerQueryHandler.class);
   private static final int DEFAULT_NUM_TRIES = 3;
   private static final int PER_SERVER_QUERY_NUM_TRIES = 5;
-  private final String dataSource;
+  private final int inputNumber;
+  private final String dataSourceName;
   private final ChannelCounters channelCounters;
   private final ServiceClientFactory serviceClientFactory;
   private final CoordinatorClient coordinatorClient;
@@ -87,7 +88,8 @@ public class IndexerDataServerQueryHandler implements 
DataServerQueryHandler
   private final DataServerRequestDescriptor dataServerRequestDescriptor;
 
   public IndexerDataServerQueryHandler(
-      String dataSource,
+      int inputNumber,
+      String dataSourceName,
       ChannelCounters channelCounters,
       ServiceClientFactory serviceClientFactory,
       CoordinatorClient coordinatorClient,
@@ -96,7 +98,8 @@ public class IndexerDataServerQueryHandler implements 
DataServerQueryHandler
       DataServerRequestDescriptor dataServerRequestDescriptor
   )
   {
-    this.dataSource = dataSource;
+    this.inputNumber = inputNumber;
+    this.dataSourceName = dataSourceName;
     this.channelCounters = channelCounters;
     this.serviceClientFactory = serviceClientFactory;
     this.coordinatorClient = coordinatorClient;
@@ -140,7 +143,7 @@ public class IndexerDataServerQueryHandler implements 
DataServerQueryHandler
       Closer closer
   )
   {
-    final Query<QueryType> preparedQuery = 
DataServerQueryHandlerUtils.prepareQuery(query, dataSource);
+    final Query<QueryType> preparedQuery = 
DataServerQueryHandlerUtils.prepareQuery(query, inputNumber, dataSourceName);
     final List<Yielder<RowType>> yielders = new ArrayList<>();
     final List<RichSegmentDescriptor> handedOffSegments = new ArrayList<>();
 
@@ -206,7 +209,7 @@ public class IndexerDataServerQueryHandler implements 
DataServerQueryHandler
 
     // Not actually async. The retry logic above is written in synchronous 
fashion. Just return an immediate-future
     // when we actually have all queries issued and all yielders set up.
-    return Futures.immediateFuture(new DataServerQueryResult<>(yielders, 
handedOffSegments, dataSource));
+    return Futures.immediateFuture(new DataServerQueryResult<>(yielders, 
handedOffSegments, dataSourceName));
   }
 
   private <QueryType, RowType> Yielder<RowType> 
fetchRowsFromDataServerInternal(
@@ -286,7 +289,7 @@ public class IndexerDataServerQueryHandler implements 
DataServerQueryHandler
 
     Iterable<ImmutableSegmentLoadInfo> immutableSegmentLoadInfos =
         coordinatorClient.fetchServerViewSegments(
-            dataSource,
+            dataSourceName,
             
richSegmentDescriptors.stream().map(RichSegmentDescriptor::getFullInterval).collect(Collectors.toList())
         );
 
@@ -358,7 +361,7 @@ public class IndexerDataServerQueryHandler implements 
DataServerQueryHandler
 
       for (SegmentDescriptor segmentDescriptor : segmentDescriptors) {
         Boolean wasHandedOff = FutureUtils.get(
-            coordinatorClient.isHandoffComplete(dataSource, segmentDescriptor),
+            coordinatorClient.isHandoffComplete(dataSourceName, 
segmentDescriptor),
             true
         );
         if (Boolean.TRUE.equals(wasHandedOff)) {
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandlerFactory.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandlerFactory.java
index 46e06a5484b..cf93c398ba2 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandlerFactory.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandlerFactory.java
@@ -52,13 +52,15 @@ public class IndexerDataServerQueryHandlerFactory 
implements DataServerQueryHand
 
   @Override
   public IndexerDataServerQueryHandler createDataServerQueryHandler(
-      String dataSource,
+      int inputNumber,
+      String dataSourceName,
       ChannelCounters channelCounters,
       DataServerRequestDescriptor requestDescriptor
   )
   {
     return new IndexerDataServerQueryHandler(
-        dataSource,
+        inputNumber,
+        dataSourceName,
         channelCounters,
         serviceClientFactory,
         coordinatorClient,
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentsInputSliceReader.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentsInputSliceReader.java
index 2233ad56473..fe59cd2b17b 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentsInputSliceReader.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentsInputSliceReader.java
@@ -85,6 +85,7 @@ public class SegmentsInputSliceReader implements 
InputSliceReader
       Iterator<ReadableInput> dataServerIterator =
           Iterators.transform(
               dataServerIterator(
+                  inputNumber,
                   segmentsInputSlice.getDataSource(),
                   segmentsInputSlice.getServedSegments(),
                   
counters.channel(CounterNames.inputChannel(inputNumber)).setTotalFiles(slice.fileCount())
@@ -118,6 +119,7 @@ public class SegmentsInputSliceReader implements 
InputSliceReader
   }
 
   private Iterator<DataServerQueryHandler> dataServerIterator(
+      final int inputNumber,
       final String dataSource,
       final List<DataServerRequestDescriptor> servedSegments,
       final ChannelCounters channelCounters
@@ -125,6 +127,7 @@ public class SegmentsInputSliceReader implements 
InputSliceReader
   {
     return servedSegments.stream().map(
         dataServerRequestDescriptor -> 
dataServerQueryHandlerFactory.createDataServerQueryHandler(
+            inputNumber,
             dataSource,
             channelCounters,
             dataServerRequestDescriptor
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/DataServerQueryHandlerUtilsTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/DataServerQueryHandlerUtilsTest.java
new file mode 100644
index 00000000000..61832877c40
--- /dev/null
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/DataServerQueryHandlerUtilsTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import org.apache.druid.msq.querykit.InputNumberDataSource;
+import org.apache.druid.msq.querykit.RestrictedInputNumberDataSource;
+import org.apache.druid.query.DataSource;
+import org.apache.druid.query.RestrictedDataSource;
+import org.apache.druid.query.TableDataSource;
+import org.apache.druid.query.policy.NoRestrictionPolicy;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class DataServerQueryHandlerUtilsTest
+{
+  @Test
+  public void testTransformDatasource()
+  {
+    DataSource ds = DataServerQueryHandlerUtils.transformDatasource(new 
InputNumberDataSource(1), 1, "foo");
+    Assertions.assertEquals(ds, TableDataSource.create("foo"));
+  }
+
+  @Test
+  public void testTransformRestrictedDatasource()
+  {
+    DataSource ds = DataServerQueryHandlerUtils.transformDatasource(
+        new RestrictedInputNumberDataSource(
+            1,
+            NoRestrictionPolicy.instance()
+        ),
+        1,
+        "foo"
+    );
+    Assertions.assertEquals(ds, 
RestrictedDataSource.create(TableDataSource.create("foo"), 
NoRestrictionPolicy.instance()));
+  }
+}
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandlerTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandlerTest.java
index 2285e49b7fc..d874fa452d7 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandlerTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerDataServerQueryHandlerTest.java
@@ -40,6 +40,7 @@ import org.apache.druid.msq.input.table.RichSegmentDescriptor;
 import org.apache.druid.msq.querykit.InputNumberDataSource;
 import org.apache.druid.msq.querykit.scan.ScanQueryFrameProcessor;
 import org.apache.druid.msq.util.MultiStageQueryContext;
+import org.apache.druid.query.FilteredDataSource;
 import org.apache.druid.query.MapQueryToolChestWarehouse;
 import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryContexts;
@@ -124,7 +125,7 @@ public class IndexerDataServerQueryHandlerTest
     dataServerClient2 = mock(DataServerClient.class);
     coordinatorClient = mock(CoordinatorClient.class);
     query = newScanQueryBuilder()
-        .dataSource(new InputNumberDataSource(1))
+        .dataSource(FilteredDataSource.create(new InputNumberDataSource(1), 
null))
         .intervals(new 
MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("2003/2004"))))
         .columns("__time", "cnt", "dim1", "dim2", "m1", "m2", "unique_dim1")
         .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
@@ -137,6 +138,7 @@ public class IndexerDataServerQueryHandlerTest
     );
     target = spy(
         new IndexerDataServerQueryHandler(
+            1,
             DATASOURCE1,
             new ChannelCounters(),
             mock(ServiceClientFactory.class),
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java
index b287458cf6b..1d3242b114f 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java
@@ -22,6 +22,7 @@ package org.apache.druid.msq.test;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.ListenableFuture;
 import com.google.inject.Binder;
 import com.google.inject.Inject;
 import com.google.inject.Module;
@@ -38,14 +39,17 @@ import org.apache.druid.guice.annotations.Self;
 import org.apache.druid.indexing.common.SegmentCacheManagerFactory;
 import org.apache.druid.initialization.DruidModule;
 import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.msq.counters.ChannelCounters;
 import org.apache.druid.msq.exec.DataServerQueryHandler;
 import org.apache.druid.msq.exec.DataServerQueryHandlerFactory;
+import org.apache.druid.msq.exec.DataServerQueryResult;
 import org.apache.druid.msq.guice.MSQExternalDataSourceModule;
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.msq.querykit.DataSegmentProvider;
 import org.apache.druid.query.ForwardingQueryProcessingPool;
+import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryProcessingPool;
 import org.apache.druid.query.groupby.TestGroupByBuffers;
 import org.apache.druid.segment.CompleteSegment;
@@ -62,7 +66,6 @@ import 
org.apache.druid.server.coordination.NoopDataSegmentAnnouncer;
 import org.apache.druid.sql.calcite.TempDirProducer;
 import org.apache.druid.timeline.SegmentId;
 import org.easymock.EasyMock;
-import org.mockito.Mockito;
 
 import java.io.File;
 import java.util.List;
@@ -70,11 +73,6 @@ import java.util.Set;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.doThrow;
-
 /**
  * Helper class aiding in wiring up the Guice bindings required for MSQ engine 
to work with the Calcite's tests
  */
@@ -190,11 +188,17 @@ public class CalciteMSQTestsHelper
     // Currently, there is no metadata in this test for loaded segments. 
Therefore, this should not be called.
     // In the future, if this needs to be supported, mocks for 
DataServerQueryHandler should be added like
     // org.apache.druid.msq.exec.MSQLoadedSegmentTests.
-    DataServerQueryHandlerFactory mockFactory = 
Mockito.mock(DataServerQueryHandlerFactory.class);
-    DataServerQueryHandler dataServerQueryHandler = 
Mockito.mock(DataServerQueryHandler.class);
-    doThrow(new AssertionError("Test does not support loaded segment query"))
-        .when(dataServerQueryHandler).fetchRowsFromDataServer(any(), any(), 
any());
-    
doReturn(dataServerQueryHandler).when(mockFactory).createDataServerQueryHandler(anyString(),
 any(), any());
-    return mockFactory;
+    return (inputNumber, dataSourceName, channelCounters, 
dataServerRequestDescriptor) -> new DataServerQueryHandler()
+    {
+      @Override
+      public <RowType, QueryType> 
ListenableFuture<DataServerQueryResult<RowType>> fetchRowsFromDataServer(
+          Query<QueryType> query,
+          Function<Sequence<QueryType>, Sequence<RowType>> mappingFunction,
+          Closer closer
+      )
+      {
+        throw new AssertionError("Test does not support loaded segment query");
+      }
+    };
   }
 }
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
index 3deba56ad7d..279f5c751e4 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
@@ -244,9 +244,6 @@ import static 
org.apache.druid.sql.calcite.util.CalciteTests.WIKIPEDIA;
 import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS1;
 import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS2;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 
 /**
@@ -646,11 +643,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
 
   private DataServerQueryHandlerFactory getTestDataServerQueryHandlerFactory()
   {
-    DataServerQueryHandlerFactory mockFactory = 
Mockito.mock(DataServerQueryHandlerFactory.class);
-    doReturn(dataServerQueryHandler)
-        .when(mockFactory)
-        .createDataServerQueryHandler(anyString(), any(), any());
-    return mockFactory;
+    return (inputNumber, dataSourceName, channelCounters, 
dataServerRequestDescriptor) -> dataServerQueryHandler;
   }
 
   protected List<Number> getEmittedMetrics(String metricName, Map<String, 
Object> dimensionFilters)
diff --git 
a/services/src/test/java/org/apache/druid/testing/embedded/EmbeddedClusterApis.java
 
b/services/src/test/java/org/apache/druid/testing/embedded/EmbeddedClusterApis.java
index 349518cf4ef..6e531848f9e 100644
--- 
a/services/src/test/java/org/apache/druid/testing/embedded/EmbeddedClusterApis.java
+++ 
b/services/src/test/java/org/apache/druid/testing/embedded/EmbeddedClusterApis.java
@@ -108,12 +108,22 @@ public class EmbeddedClusterApis
    * throwing an exception upon timeout.
    */
   public void waitForTaskToSucceed(String taskId, EmbeddedOverlord overlord)
+  {
+    waitForTaskToFinish(taskId, overlord);
+    verifyTaskHasStatus(taskId, TaskStatus.success(taskId));
+  }
+
+  /**
+   * Waits for the given task to finish (either successfully or 
unsuccessfully). If the given
+   * {@link EmbeddedOverlord} is not the leader, this method can only return by
+   * throwing an exception upon timeout.
+   */
+  public void waitForTaskToFinish(String taskId, EmbeddedOverlord overlord)
   {
     overlord.latchableEmitter().waitForEvent(
         event -> event.hasMetricName(TaskMetrics.RUN_DURATION)
                       .hasDimension(DruidMetrics.TASK_ID, taskId)
     );
-    verifyTaskHasStatus(taskId, TaskStatus.success(taskId));
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to