This is an automated email from the ASF dual-hosted git repository.

shauryachats pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 065d50c3f77 Extend FUNNEL_COUNT to support multiple CORRELATE_BY columns (#18760)
065d50c3f77 is described below

commit 065d50c3f77a4b1a5b2571ea92a1aa5c71036ba3
Author: tarun11Mavani <[email protected]>
AuthorDate: Tue Jun 30 04:00:44 2026 +0530

    Extend FUNNEL_COUNT to support multiple CORRELATE_BY columns (#18760)
    
    * Extend FUNNEL_COUNT to support multiple CORRELATE_BY columns
    
    Enable funnel analysis that tracks users through steps within a composite
    key (e.g., per user per device category) by accepting multiple columns in
    CORRELATE_BY(col1, col2, ...).
    
    The single-key path is preserved as a zero-overhead fast path with separate
    addSingleKey/addMultiKey abstract methods and dedicated aggregation loops,
    ensuring no regression for existing single-column queries.
    
    Multi-key composite ID mapping uses stride-based arithmetic when the product
    of dictionary sizes fits in int, with a HashMap fallback for large key
    spaces.
    
    Co-authored-by: Cursor <[email protected]>
    
    * Remove benchmark file from PR
    
    Benchmark was used for local validation only; not needed in the PR.
    
    Co-authored-by: Cursor <[email protected]>
    
    * Preserve original add() signature for backward compatibility
    
    Keep the original `add(Dictionary, A, int, int)` abstract method unchanged.
    The new multi-key method is added as
    `addMultiKey(A, int, Dictionary[], int[])`.
    
    Co-authored-by: Cursor <[email protected]>
    
    * Add tests for DictIdsWrapper HashMap fallback path and fix SortedAggregationResult double-count
    
    - Add DictIdsWrapperTest covering the HashMap fallback path (large-cardinality
      composite keys where product of dict sizes exceeds Integer.MAX_VALUE):
      path selection, sequential ID assignment, same-key idempotency,
      key-order sensitivity, and round-trip for 2- and 3-column keys.
      Also covers stride-path reverseCompositeId round-trip.
      Add isHashMapPath() predicate to DictIdsWrapper for test introspection
      (avoids widening _strides visibility).
    
    - Add SortedAggregationResultTest with multi-key extraction scenarios.
    
    - Fix SortedAggregationResult.extractResult(): clear _secondaryKeySteps after
      flushMultiKeyGroup() so a second call (defensive) returns zeros rather than
      double-counting the last open primary group.
    
    * Clarify hash approximation in BitmapResultExtractionStrategy Javadoc
    
    Add method-level doc on convertCompositeToValueBitmap linking the
    multi-key .hashCode() usage to the existing single-key non-INT
    approximation in convertToValueBitmap.
    
    * refactor(funnel): reduce allocations in sorted multi-key path and bitmap extraction
    
    SortedAggregationResult: replace HashMap<IntArrayList, boolean[]> with
    pre-allocated flat arrays and linear scan. Zero allocations in the hot
    loop for typical workloads (1-5 secondary key combos per primary group).
    
    BitmapResultExtractionStrategy: replace toCompositeString().hashCode()
    with direct type-aware hash combining, avoiding StringBuilder/String
    allocation per composite ID during extraction.
    
    ---------
    
    Co-authored-by: Cursor <[email protected]>
---
 .../function/funnel/AggregationStrategy.java       | 181 ++++++++++++++++++---
 .../function/funnel/BitmapAggregationStrategy.java |  10 ++
 .../funnel/BitmapResultExtractionStrategy.java     |  58 ++++++-
 .../function/funnel/DictIdsWrapper.java            | 136 +++++++++++++++-
 .../FunnelCountAggregationFunctionFactory.java     |   3 +
 .../FunnelCountSortedAggregationFunction.java      |  16 +-
 .../funnel/SetResultExtractionStrategy.java        |  30 +++-
 .../function/funnel/SortedAggregationResult.java   | 108 +++++++++++-
 .../function/funnel/SortedAggregationStrategy.java |  12 ++
 .../funnel/ThetaSketchAggregationStrategy.java     |  19 ++-
 .../function/funnel/DictIdsWrapperTest.java        | 128 +++++++++++++++
 .../funnel/SortedAggregationResultTest.java        |  57 +++++++
 .../integration/tests/custom/FunnelCountTest.java  | 103 +++++++++++-
 13 files changed, 811 insertions(+), 50 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/AggregationStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/AggregationStrategy.java
index 99006c102ab..1448ea0b243 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/AggregationStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/AggregationStrategy.java
@@ -37,12 +37,16 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary;
  * There should be no assumptions beyond segment boundaries, different 
aggregation strategies may be utilized
  * across different segments for a given query.
  *
+ * <p>Supports both single-key and multi-key CORRELATE_BY. The single-key path 
is kept as a zero-overhead fast path
+ * (structurally identical to the original single-column implementation) to 
avoid any regression for existing queries.
+ *
  * @param <A> Aggregation result accumulated across blocks within segment, 
kept by result holder.
  */
 @ThreadSafe
 public abstract class AggregationStrategy<A> {
 
   protected final int _numSteps;
+  protected final int _numCorrelateByKeys;
   private final List<ExpressionContext> _stepExpressions;
   private final List<ExpressionContext> _correlateByExpressions;
   private final ExpressionContext _primaryCorrelationCol;
@@ -52,13 +56,38 @@ public abstract class AggregationStrategy<A> {
     _correlateByExpressions = correlateByExpressions;
     _primaryCorrelationCol = _correlateByExpressions.get(0);
     _numSteps = _stepExpressions.size();
+    _numCorrelateByKeys = _correlateByExpressions.size();
   }
 
   /**
-   * Returns an aggregation result for this aggregation strategy to be kept in 
a result holder (aggregation only).
+   * Creates an aggregation result for single-key correlation.
    */
   abstract A createAggregationResult(Dictionary dictionary);
 
+  /**
+   * Creates an aggregation result for multi-key correlation.
+   */
+  abstract A createAggregationResultMultiKey(Dictionary[] dictionaries);
+
+  public A getAggregationResult(Dictionary dictionary, AggregationResultHolder 
aggregationResultHolder) {
+    A aggResult = aggregationResultHolder.getResult();
+    if (aggResult == null) {
+      aggResult = createAggregationResult(dictionary);
+      aggregationResultHolder.setValue(aggResult);
+    }
+    return aggResult;
+  }
+
+  public A getAggregationResultMultiKey(Dictionary[] dictionaries,
+      AggregationResultHolder aggregationResultHolder) {
+    A aggResult = aggregationResultHolder.getResult();
+    if (aggResult == null) {
+      aggResult = createAggregationResultMultiKey(dictionaries);
+      aggregationResultHolder.setValue(aggResult);
+    }
+    return aggResult;
+  }
+
   public A getAggregationResultGroupBy(Dictionary dictionary, 
GroupByResultHolder groupByResultHolder, int groupKey) {
     A aggResult = groupByResultHolder.getResult(groupKey);
     if (aggResult == null) {
@@ -68,11 +97,12 @@ public abstract class AggregationStrategy<A> {
     return aggResult;
   }
 
-  public A getAggregationResult(Dictionary dictionary, AggregationResultHolder 
aggregationResultHolder) {
-    A aggResult = aggregationResultHolder.getResult();
+  public A getAggregationResultGroupByMultiKey(Dictionary[] dictionaries, 
GroupByResultHolder groupByResultHolder,
+      int groupKey) {
+    A aggResult = groupByResultHolder.getResult(groupKey);
     if (aggResult == null) {
-      aggResult = createAggregationResult(dictionary);
-      aggregationResultHolder.setValue(aggResult);
+      aggResult = createAggregationResultMultiKey(dictionaries);
+      groupByResultHolder.setValueForKey(groupKey, aggResult);
     }
     return aggResult;
   }
@@ -82,10 +112,18 @@ public abstract class AggregationStrategy<A> {
    */
   public void aggregate(int length, AggregationResultHolder 
aggregationResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
-    final Dictionary dictionary = getDictionary(blockValSetMap);
-    final int[] correlationIds = getCorrelationIds(blockValSetMap);
     final int[][] steps = getSteps(blockValSetMap);
+    if (_numCorrelateByKeys == 1) {
+      aggregateSingleKey(length, aggregationResultHolder, blockValSetMap, 
steps);
+    } else {
+      aggregateMultiKey(length, aggregationResultHolder, blockValSetMap, 
steps);
+    }
+  }
 
+  private void aggregateSingleKey(int length, AggregationResultHolder 
aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+    final Dictionary dictionary = getPrimaryDictionary(blockValSetMap);
+    final int[] correlationIds = getPrimaryCorrelationIds(blockValSetMap);
     final A aggResult = getAggregationResult(dictionary, 
aggregationResultHolder);
     for (int i = 0; i < length; i++) {
       for (int n = 0; n < _numSteps; n++) {
@@ -96,20 +134,46 @@ public abstract class AggregationStrategy<A> {
     }
   }
 
+  private void aggregateMultiKey(int length, AggregationResultHolder 
aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+    final Dictionary[] dictionaries = getAllDictionaries(blockValSetMap);
+    final int[][] allCorrelationIds = getAllCorrelationDictIds(blockValSetMap);
+    final A aggResult = getAggregationResultMultiKey(dictionaries, 
aggregationResultHolder);
+    final int[] rowDictIds = new int[_numCorrelateByKeys];
+    for (int i = 0; i < length; i++) {
+      for (int k = 0; k < _numCorrelateByKeys; k++) {
+        rowDictIds[k] = allCorrelationIds[k][i];
+      }
+      for (int n = 0; n < _numSteps; n++) {
+        if (steps[n][i] > 0) {
+          addMultiKey(aggResult, n, dictionaries, rowDictIds);
+        }
+      }
+    }
+  }
+
   /**
    * Performs aggregation on the given group key array and block value sets 
(aggregation group-by on single-value
    * columns).
    */
   public void aggregateGroupBySV(int length, int[] groupKeyArray, 
GroupByResultHolder groupByResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
-    final Dictionary dictionary = getDictionary(blockValSetMap);
-    final int[] correlationIds = getCorrelationIds(blockValSetMap);
     final int[][] steps = getSteps(blockValSetMap);
+    if (_numCorrelateByKeys == 1) {
+      aggregateGroupBySVSingleKey(length, groupKeyArray, groupByResultHolder, 
blockValSetMap, steps);
+    } else {
+      aggregateGroupBySVMultiKey(length, groupKeyArray, groupByResultHolder, 
blockValSetMap, steps);
+    }
+  }
 
+  private void aggregateGroupBySVSingleKey(int length, int[] groupKeyArray, 
GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+    final Dictionary dictionary = getPrimaryDictionary(blockValSetMap);
+    final int[] correlationIds = getPrimaryCorrelationIds(blockValSetMap);
     for (int i = 0; i < length; i++) {
+      final int groupKey = groupKeyArray[i];
+      final A aggResult = getAggregationResultGroupBy(dictionary, 
groupByResultHolder, groupKey);
       for (int n = 0; n < _numSteps; n++) {
-        final int groupKey = groupKeyArray[i];
-        final A aggResult = getAggregationResultGroupBy(dictionary, 
groupByResultHolder, groupKey);
         if (steps[n][i] > 0) {
           add(dictionary, aggResult, n, correlationIds[i]);
         }
@@ -117,20 +181,47 @@ public abstract class AggregationStrategy<A> {
     }
   }
 
+  private void aggregateGroupBySVMultiKey(int length, int[] groupKeyArray, 
GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+    final Dictionary[] dictionaries = getAllDictionaries(blockValSetMap);
+    final int[][] allCorrelationIds = getAllCorrelationDictIds(blockValSetMap);
+    final int[] rowDictIds = new int[_numCorrelateByKeys];
+    for (int i = 0; i < length; i++) {
+      for (int k = 0; k < _numCorrelateByKeys; k++) {
+        rowDictIds[k] = allCorrelationIds[k][i];
+      }
+      final int groupKey = groupKeyArray[i];
+      final A aggResult = getAggregationResultGroupByMultiKey(dictionaries, 
groupByResultHolder, groupKey);
+      for (int n = 0; n < _numSteps; n++) {
+        if (steps[n][i] > 0) {
+          addMultiKey(aggResult, n, dictionaries, rowDictIds);
+        }
+      }
+    }
+  }
+
   /**
    * Performs aggregation on the given group keys array and block value sets 
(aggregation group-by on multi-value
    * columns).
    */
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, 
GroupByResultHolder groupByResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
-    final Dictionary dictionary = getDictionary(blockValSetMap);
-    final int[] correlationIds = getCorrelationIds(blockValSetMap);
     final int[][] steps = getSteps(blockValSetMap);
+    if (_numCorrelateByKeys == 1) {
+      aggregateGroupByMVSingleKey(length, groupKeysArray, groupByResultHolder, 
blockValSetMap, steps);
+    } else {
+      aggregateGroupByMVMultiKey(length, groupKeysArray, groupByResultHolder, 
blockValSetMap, steps);
+    }
+  }
 
+  private void aggregateGroupByMVSingleKey(int length, int[][] groupKeysArray,
+      GroupByResultHolder groupByResultHolder, Map<ExpressionContext, 
BlockValSet> blockValSetMap, int[][] steps) {
+    final Dictionary dictionary = getPrimaryDictionary(blockValSetMap);
+    final int[] correlationIds = getPrimaryCorrelationIds(blockValSetMap);
     for (int i = 0; i < length; i++) {
-      for (int n = 0; n < _numSteps; n++) {
-        for (int groupKey : groupKeysArray[i]) {
-          final A aggResult = getAggregationResultGroupBy(dictionary, 
groupByResultHolder, groupKey);
+      for (int groupKey : groupKeysArray[i]) {
+        final A aggResult = getAggregationResultGroupBy(dictionary, 
groupByResultHolder, groupKey);
+        for (int n = 0; n < _numSteps; n++) {
           if (steps[n][i] > 0) {
             add(dictionary, aggResult, n, correlationIds[i]);
           }
@@ -139,26 +230,74 @@ public abstract class AggregationStrategy<A> {
     }
   }
 
+  private void aggregateGroupByMVMultiKey(int length, int[][] groupKeysArray,
+      GroupByResultHolder groupByResultHolder, Map<ExpressionContext, 
BlockValSet> blockValSetMap, int[][] steps) {
+    final Dictionary[] dictionaries = getAllDictionaries(blockValSetMap);
+    final int[][] allCorrelationIds = getAllCorrelationDictIds(blockValSetMap);
+    final int[] rowDictIds = new int[_numCorrelateByKeys];
+    for (int i = 0; i < length; i++) {
+      for (int k = 0; k < _numCorrelateByKeys; k++) {
+        rowDictIds[k] = allCorrelationIds[k][i];
+      }
+      for (int groupKey : groupKeysArray[i]) {
+        final A aggResult = getAggregationResultGroupByMultiKey(dictionaries, 
groupByResultHolder, groupKey);
+        for (int n = 0; n < _numSteps; n++) {
+          if (steps[n][i] > 0) {
+            addMultiKey(aggResult, n, dictionaries, rowDictIds);
+          }
+        }
+      }
+    }
+  }
+
   /**
    * Adds a correlation id to the aggregation counter for a given step in the 
funnel.
    */
   abstract void add(Dictionary dictionary, A aggResult, int step, int 
correlationId);
 
-  private Dictionary getDictionary(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
+  /**
+   * Adds a row's composite correlation identity to the aggregation counter 
for a given step (multi-key path).
+   *
+   * @param aggResult          the aggregation result to update
+   * @param step               the funnel step index
+   * @param dictionaries       one dictionary per correlate-by column
+   * @param correlationDictIds one dictionary ID per correlate-by column for 
the current row
+   *                           (this array is reused across rows; 
implementations must not hold a reference)
+   */
+  abstract void addMultiKey(A aggResult, int step, Dictionary[] dictionaries, 
int[] correlationDictIds);
+
+  Dictionary getPrimaryDictionary(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
     final BlockValSet primaryCorrelationValSet = 
blockValSetMap.get(_primaryCorrelationCol);
-    // FUNNELCOUNT requires dict-id reads from the forward index; a column 
with EncodingType.RAW + dictionaryIndex
-    // exposes a Dictionary but BlockValSet#getDictionaryIdsSV throws on the 
RAW forward index. Gate on the
-    // explicit forward-index encoding flag rather than dictionary nullness 
alone.
     Preconditions.checkArgument(primaryCorrelationValSet.isDictionaryEncoded(),
         "CORRELATE_BY column in FUNNELCOUNT aggregation function not 
supported, please use a dictionary encoded "
             + "column.");
     return primaryCorrelationValSet.getDictionary();
   }
 
-  private int[] getCorrelationIds(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
+  private Dictionary[] getAllDictionaries(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
+    Dictionary[] dictionaries = new Dictionary[_numCorrelateByKeys];
+    for (int k = 0; k < _numCorrelateByKeys; k++) {
+      BlockValSet valSet = blockValSetMap.get(_correlateByExpressions.get(k));
+      Preconditions.checkArgument(valSet.isDictionaryEncoded(),
+          "CORRELATE_BY column in FUNNELCOUNT aggregation function not 
supported, please use a dictionary encoded "
+              + "column.");
+      dictionaries[k] = valSet.getDictionary();
+    }
+    return dictionaries;
+  }
+
+  private int[] getPrimaryCorrelationIds(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
     return blockValSetMap.get(_primaryCorrelationCol).getDictionaryIdsSV();
   }
 
+  private int[][] getAllCorrelationDictIds(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
+    int[][] allIds = new int[_numCorrelateByKeys][];
+    for (int k = 0; k < _numCorrelateByKeys; k++) {
+      allIds[k] = 
blockValSetMap.get(_correlateByExpressions.get(k)).getDictionaryIdsSV();
+    }
+    return allIds;
+  }
+
   private int[][] getSteps(Map<ExpressionContext, BlockValSet> blockValSetMap) 
{
     final int[][] steps = new int[_numSteps][];
     for (int n = 0; n < _numSteps; n++) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapAggregationStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapAggregationStrategy.java
index f726d936205..c0f3019fa1c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapAggregationStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapAggregationStrategy.java
@@ -37,8 +37,18 @@ class BitmapAggregationStrategy extends 
AggregationStrategy<DictIdsWrapper> {
     return new DictIdsWrapper(_numSteps, dictionary);
   }
 
+  @Override
+  public DictIdsWrapper createAggregationResultMultiKey(Dictionary[] 
dictionaries) {
+    return new DictIdsWrapper(_numSteps, dictionaries);
+  }
+
   @Override
   protected void add(Dictionary dictionary, DictIdsWrapper dictIdsWrapper, int 
step, int correlationId) {
     dictIdsWrapper._stepsBitmaps[step].add(correlationId);
   }
+
+  @Override
+  void addMultiKey(DictIdsWrapper dictIdsWrapper, int step, Dictionary[] 
dictionaries, int[] correlationDictIds) {
+    
dictIdsWrapper._stepsBitmaps[step].add(dictIdsWrapper.getCompositeCorrelationId(correlationDictIds));
+  }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapResultExtractionStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapResultExtractionStrategy.java
index 1611b8dae8f..4b0373123df 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapResultExtractionStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapResultExtractionStrategy.java
@@ -26,6 +26,13 @@ import org.roaringbitmap.PeekableIntIterator;
 import org.roaringbitmap.RoaringBitmap;
 
 
+/**
+ * Extracts intermediate bitmap results for cross-segment merging.
+ *
+ * <p>The bitmap strategy stores entities as 32-bit hash codes in a {@link 
RoaringBitmap}. For single-key INT
+ * columns, the actual int values are stored directly (exact). For other 
single-key types and all multi-key
+ * composites, hash codes are used (approximate — hash collisions can cause 
under-counting).
+ */
 class BitmapResultExtractionStrategy implements 
ResultExtractionStrategy<DictIdsWrapper, List<RoaringBitmap>> {
   protected final int _numSteps;
 
@@ -42,14 +49,59 @@ class BitmapResultExtractionStrategy implements 
ResultExtractionStrategy<DictIds
       }
       return result;
     }
-    Dictionary dictionary = dictIdsWrapper._dictionary;
     List<RoaringBitmap> result = new ArrayList<>(_numSteps);
-    for (RoaringBitmap dictIdBitmap : dictIdsWrapper._stepsBitmaps) {
-      result.add(convertToValueBitmap(dictionary, dictIdBitmap));
+    if (dictIdsWrapper.isMultiKey()) {
+      for (RoaringBitmap compositeIdBitmap : dictIdsWrapper._stepsBitmaps) {
+        result.add(convertCompositeToValueBitmap(dictIdsWrapper, 
compositeIdBitmap));
+      }
+    } else {
+      Dictionary dictionary = dictIdsWrapper._dictionaries[0];
+      for (RoaringBitmap dictIdBitmap : dictIdsWrapper._stepsBitmaps) {
+        result.add(convertToValueBitmap(dictionary, dictIdBitmap));
+      }
     }
     return result;
   }
 
+  /// Converts segment-local composite dictionary IDs to hash-coded value 
bitmaps for cross-segment merging.
+  /// Combines per-column value hashes directly — no string allocation. Same 
approximation as the
+  /// single-key non-INT path in {@link #convertToValueBitmap}: hash 
collisions may cause under-counting.
+  private RoaringBitmap convertCompositeToValueBitmap(DictIdsWrapper wrapper, 
RoaringBitmap compositeIdBitmap) {
+    RoaringBitmap valueBitmap = new RoaringBitmap();
+    PeekableIntIterator iterator = compositeIdBitmap.getIntIterator();
+    int numKeys = wrapper._dictionaries.length;
+    int[] dictIds = new int[numKeys];
+    while (iterator.hasNext()) {
+      wrapper.reverseCompositeId(iterator.next(), dictIds);
+      int hash = 1;
+      for (int k = 0; k < numKeys; k++) {
+        hash = 31 * hash + valueHashCode(wrapper._dictionaries[k], dictIds[k]);
+      }
+      valueBitmap.add(hash);
+    }
+    return valueBitmap;
+  }
+
+  /// Returns the hash code of a dictionary value using its native type, 
avoiding string conversion
+  /// for numeric types.
+  private static int valueHashCode(Dictionary dictionary, int dictId) {
+    switch (dictionary.getValueType()) {
+      case INT:
+        return Integer.hashCode(dictionary.getIntValue(dictId));
+      case LONG:
+        return Long.hashCode(dictionary.getLongValue(dictId));
+      case FLOAT:
+        return Float.hashCode(dictionary.getFloatValue(dictId));
+      case DOUBLE:
+        return Double.hashCode(dictionary.getDoubleValue(dictId));
+      case STRING:
+        return dictionary.getStringValue(dictId).hashCode();
+      default:
+        throw new IllegalArgumentException("Illegal data type for FUNNEL_COUNT 
aggregation function: "
+            + dictionary.getValueType());
+    }
+  }
+
   /**
    * Helper method to read dictionary and convert dictionary ids to hash code 
of the values for dictionary-encoded
    * expression.
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapper.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapper.java
index c09d0128f29..a778d48319b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapper.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapper.java
@@ -18,19 +18,151 @@
  */
 package org.apache.pinot.core.query.aggregation.function.funnel;
 
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import org.apache.pinot.segment.spi.index.reader.Dictionary;
 import org.roaringbitmap.RoaringBitmap;
 
 
+/**
+ * Holds per-step RoaringBitmaps keyed by correlation dictionary IDs.
+ *
+ * <p>For single-key CORRELATE_BY, stores raw dictionary IDs directly in the 
bitmaps (compact, fits in one
+ * RoaringBitmap container for typical segment sizes).
+ *
+ * <p>For multi-key CORRELATE_BY, composite IDs are assigned via stride-based 
arithmetic (when the combined key
+ * space fits in int) or a HashMap fallback for large key spaces.
+ */
 final class DictIdsWrapper {
-  final Dictionary _dictionary;
+  final Dictionary[] _dictionaries;
   final RoaringBitmap[] _stepsBitmaps;
 
+  // Stride-based composite mapping (non-null only for multi-key when product 
of dict sizes fits in int)
+  private final int[] _strides;
+
+  // HashMap-based composite mapping (non-null only for multi-key when stride 
overflows int)
+  private final Map<IntArrayList, Integer> _compositeKeyMap;
+  private final List<int[]> _compositeKeyReverse;
+  private final IntArrayList _lookupKey;
+
   DictIdsWrapper(int numSteps, Dictionary dictionary) {
-    _dictionary = dictionary;
+    _dictionaries = new Dictionary[]{dictionary};
     _stepsBitmaps = new RoaringBitmap[numSteps];
     for (int n = 0; n < numSteps; n++) {
       _stepsBitmaps[n] = new RoaringBitmap();
     }
+    _strides = null;
+    _compositeKeyMap = null;
+    _compositeKeyReverse = null;
+    _lookupKey = null;
+  }
+
+  DictIdsWrapper(int numSteps, Dictionary[] dictionaries) {
+    _dictionaries = dictionaries;
+    _stepsBitmaps = new RoaringBitmap[numSteps];
+    for (int n = 0; n < numSteps; n++) {
+      _stepsBitmaps[n] = new RoaringBitmap();
+    }
+
+    if (dictionaries.length > 1) {
+      long totalSpace = 1;
+      boolean fitsInInt = true;
+      for (Dictionary d : dictionaries) {
+        totalSpace *= d.length();
+        if (totalSpace > Integer.MAX_VALUE) {
+          fitsInInt = false;
+          break;
+        }
+      }
+
+      if (fitsInInt) {
+        _strides = new int[dictionaries.length];
+        _strides[dictionaries.length - 1] = 1;
+        for (int k = dictionaries.length - 2; k >= 0; k--) {
+          _strides[k] = _strides[k + 1] * dictionaries[k + 1].length();
+        }
+        _compositeKeyMap = null;
+        _compositeKeyReverse = null;
+        _lookupKey = null;
+      } else {
+        _strides = null;
+        _compositeKeyMap = new HashMap<>();
+        _compositeKeyReverse = new ArrayList<>();
+        _lookupKey = new IntArrayList(dictionaries.length);
+      }
+    } else {
+      _strides = null;
+      _compositeKeyMap = null;
+      _compositeKeyReverse = null;
+      _lookupKey = null;
+    }
+  }
+
+  boolean isMultiKey() {
+    return _dictionaries.length > 1;
+  }
+
+  boolean isHashMapPath() {
+    return _compositeKeyMap != null;
+  }
+
+  /**
+   * Maps a tuple of per-column dictionary IDs to a single composite int 
suitable for RoaringBitmap.
+   * Only used for multi-key; for single-key, callers should add the dictId 
directly.
+   */
+  int getCompositeCorrelationId(int[] dictIds) {
+    if (_strides != null) {
+      int id = 0;
+      for (int k = 0; k < dictIds.length; k++) {
+        id += dictIds[k] * _strides[k];
+      }
+      return id;
+    }
+    _lookupKey.clear();
+    for (int dictId : dictIds) {
+      _lookupKey.add(dictId);
+    }
+    Integer existingId = _compositeKeyMap.get(_lookupKey);
+    if (existingId != null) {
+      return existingId;
+    }
+    IntArrayList insertKey = new IntArrayList(dictIds);
+    int id = _compositeKeyReverse.size();
+    _compositeKeyMap.put(insertKey, id);
+    _compositeKeyReverse.add(dictIds.clone());
+    return id;
+  }
+
+  /**
+   * Builds a collision-free composite string from dictionary values using 
length-prefix encoding.
+   * Each component is encoded as {@code length:value}, e.g. ("alice", "home") 
becomes "5:alice4:home".
+   */
+  static String toCompositeString(Dictionary[] dictionaries, int[] dictIds) {
+    StringBuilder sb = new StringBuilder();
+    for (int k = 0; k < dictionaries.length; k++) {
+      String val = dictionaries[k].getStringValue(dictIds[k]);
+      sb.append(val.length()).append(':').append(val);
+    }
+    return sb.toString();
+  }
+
+  /**
+   * Reverse-maps a composite ID back to per-column dictionary IDs.
+   */
+  void reverseCompositeId(int compositeId, int[] outDictIds) {
+    if (_strides != null) {
+      int remaining = compositeId;
+      for (int k = 0; k < outDictIds.length - 1; k++) {
+        outDictIds[k] = remaining / _strides[k];
+        remaining %= _strides[k];
+      }
+      outDictIds[outDictIds.length - 1] = remaining;
+      return;
+    }
+    int[] stored = _compositeKeyReverse.get(compositeId);
+    System.arraycopy(stored, 0, outDictIds, 0, outDictIds.length);
   }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunctionFactory.java
index 5d0fb1eeb85..91e9bfca51a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunctionFactory.java
@@ -168,6 +168,9 @@ public class FunnelCountAggregationFunctionFactory 
implements Supplier<Aggregati
 
   ResultExtractionStrategy<DictIdsWrapper, List<Long>> 
bitmapPartitionedResultExtractionStrategy() {
     final MergeStrategy<List<RoaringBitmap>> bitmapMergeStrategy = 
bitmapMergeStrategy();
+    // For partitioned mode, each segment is self-contained: every row for a 
given correlation key
+    // appears in exactly one segment. Therefore we can count bitmap 
cardinality directly without
+    // converting segment-local composite IDs to global values — they will 
never be merged across segments.
     return dictIdsWrapper -> {
       if (dictIdsWrapper == null) {
         return Collections.nCopies(_numSteps, 0L);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountSortedAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountSortedAggregationFunction.java
index ac39461cef5..a86c0e22e85 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountSortedAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountSortedAggregationFunction.java
@@ -34,6 +34,8 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary;
  * It leverages a more efficient counting strategy for segments sorted by 
correlate_by column, falls back to a regular
  * counting strategy for unsorted segments (e.g. uncommitted segments).
  *
+ * <p>For multi-key correlate-by, the sorted/partitioned optimization applies 
to the first (primary) column only.
+ *
  * Example:
  *   SELECT
  *    dateTrunc('day', timestamp) AS ts,
@@ -59,14 +61,14 @@ public class FunnelCountSortedAggregationFunction<A> 
extends FunnelCountAggregat
     super(expressions, stepExpressions, correlateByExpressions, 
aggregationStrategy, resultExtractionStrategy,
         mergeStrategy);
     _sortedAggregationStrategy = new 
SortedAggregationStrategy(stepExpressions, correlateByExpressions);
-    _sortedResultExtractionStrategy = SortedAggregationResult::extractResult;;
+    _sortedResultExtractionStrategy = SortedAggregationResult::extractResult;
     _primaryCorrelationCol = correlateByExpressions.get(0);
   }
 
   @Override
   public void aggregate(int length, AggregationResultHolder 
aggregationResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
-    if (isSortedDictionary(blockValSetMap)) {
+    if (isPrimarySortedDictionary(blockValSetMap)) {
       _sortedAggregationStrategy.aggregate(length, aggregationResultHolder, 
blockValSetMap);
     } else {
       super.aggregate(length, aggregationResultHolder, blockValSetMap);
@@ -76,7 +78,7 @@ public class FunnelCountSortedAggregationFunction<A> extends 
FunnelCountAggregat
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, 
GroupByResultHolder groupByResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
-    if (isSortedDictionary(blockValSetMap)) {
+    if (isPrimarySortedDictionary(blockValSetMap)) {
       _sortedAggregationStrategy.aggregateGroupBySV(length, groupKeyArray, 
groupByResultHolder, blockValSetMap);
     } else {
       super.aggregateGroupBySV(length, groupKeyArray, groupByResultHolder, 
blockValSetMap);
@@ -86,7 +88,7 @@ public class FunnelCountSortedAggregationFunction<A> extends 
FunnelCountAggregat
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, 
GroupByResultHolder groupByResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
-    if (isSortedDictionary(blockValSetMap)) {
+    if (isPrimarySortedDictionary(blockValSetMap)) {
       _sortedAggregationStrategy.aggregateGroupByMV(length, groupKeysArray, 
groupByResultHolder, blockValSetMap);
     } else {
       super.aggregateGroupByMV(length, groupKeysArray, groupByResultHolder, 
blockValSetMap);
@@ -111,15 +113,15 @@ public class FunnelCountSortedAggregationFunction<A> 
extends FunnelCountAggregat
     }
   }
 
-  private boolean isSortedDictionary(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
-    return getDictionary(blockValSetMap).isSorted();
+  private boolean isPrimarySortedDictionary(Map<ExpressionContext, 
BlockValSet> blockValSetMap) {
+    return getPrimaryDictionary(blockValSetMap).isSorted();
   }
 
   private boolean isSortedAggResult(Object aggResult) {
     return aggResult instanceof SortedAggregationResult;
   }
 
-  private Dictionary getDictionary(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
+  private Dictionary getPrimaryDictionary(Map<ExpressionContext, BlockValSet> 
blockValSetMap) {
     final Dictionary primaryCorrelationDictionary = 
blockValSetMap.get(_primaryCorrelationCol).getDictionary();
     Preconditions.checkArgument(primaryCorrelationDictionary != null,
         "CORRELATE_BY column in FUNNELCOUNT aggregation function not supported 
for sorted setting, "
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SetResultExtractionStrategy.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SetResultExtractionStrategy.java
index fad2bbf033a..675288b43c5 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SetResultExtractionStrategy.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SetResultExtractionStrategy.java
@@ -33,7 +33,10 @@ import org.roaringbitmap.RoaringBitmap;
 
 
 /**
- * Aggregation strategy leveraging set algebra (unions/intersections).
+ * Extracts intermediate set results for cross-segment merging.
+ *
+ * <p>For single-key, converts dictionary IDs to typed value sets. For multi-key, converts composite IDs
+ * to length-prefix-encoded composite strings, producing a {@code Set<String>} per step.
  */
 class SetResultExtractionStrategy implements 
ResultExtractionStrategy<DictIdsWrapper, List<Set>> {
   protected final int _numSteps;
@@ -51,14 +54,33 @@ class SetResultExtractionStrategy implements 
ResultExtractionStrategy<DictIdsWra
       }
       return result;
     }
-    Dictionary dictionary = dictIdsWrapper._dictionary;
     List<Set> result = new ArrayList<>(_numSteps);
-    for (RoaringBitmap dictIdBitmap : dictIdsWrapper._stepsBitmaps) {
-      result.add(convertToValueSet(dictionary, dictIdBitmap));
+    if (dictIdsWrapper.isMultiKey()) {
+      for (RoaringBitmap compositeIdBitmap : dictIdsWrapper._stepsBitmaps) {
+        result.add(convertCompositeToValueSet(dictIdsWrapper, 
compositeIdBitmap));
+      }
+    } else {
+      Dictionary dictionary = dictIdsWrapper._dictionaries[0];
+      for (RoaringBitmap dictIdBitmap : dictIdsWrapper._stepsBitmaps) {
+        result.add(convertToValueSet(dictionary, dictIdBitmap));
+      }
     }
     return result;
   }
 
+  private Set<String> convertCompositeToValueSet(DictIdsWrapper wrapper, 
RoaringBitmap compositeIdBitmap) {
+    int numValues = compositeIdBitmap.getCardinality();
+    int numKeys = wrapper._dictionaries.length;
+    int[] dictIds = new int[numKeys];
+    ObjectOpenHashSet<String> stringSet = new ObjectOpenHashSet<>(numValues);
+    PeekableIntIterator iterator = compositeIdBitmap.getIntIterator();
+    while (iterator.hasNext()) {
+      wrapper.reverseCompositeId(iterator.next(), dictIds);
+      stringSet.add(DictIdsWrapper.toCompositeString(wrapper._dictionaries, 
dictIds));
+    }
+    return stringSet;
+  }
+
   private Set convertToValueSet(Dictionary dictionary, RoaringBitmap 
dictIdBitmap) {
     int numValues = dictIdBitmap.getCardinality();
     PeekableIntIterator iterator = dictIdBitmap.getIntIterator();
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResult.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResult.java
index eb773eac7ed..cf4bb2aa05a 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResult.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResult.java
@@ -19,29 +19,49 @@
 package org.apache.pinot.core.query.aggregation.function.funnel;
 
 import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.util.Arrays;
 
 
 /**
  * Aggregation result data structure leveraged by sorted aggregation strategy.
+ *
+ * <p>For single-key, uses simple last-ID tracking since data is sorted by the correlation column.
+ * For multi-key, data is sorted by the primary (first) correlation column only; secondary keys
+ * are tracked via pre-allocated flat arrays within each primary-key group.
  */
 class SortedAggregationResult {
+  private static final int INITIAL_CAPACITY = 8;
+
   final int _numSteps;
   final long[] _stepCounters;
+  private final int _numKeys;
+
+  // Single-key tracking
   final boolean[] _correlatedSteps;
   int _lastCorrelationId = Integer.MIN_VALUE;
 
+  // Multi-key tracking — flat arrays, pre-allocated once and reused across groups
+  private int _lastPrimaryId = Integer.MIN_VALUE;
+  private int[][] _entryKeys;
+  private boolean[][] _entrySteps;
+  private int _entryCount;
+
   SortedAggregationResult(int numSteps) {
+    this(numSteps, 1);
+  }
+
+  SortedAggregationResult(int numSteps, int numKeys) {
     _numSteps = numSteps;
-    _stepCounters = new long[_numSteps];
-    _correlatedSteps = new boolean[_numSteps];
+    _numKeys = numKeys;
+    _stepCounters = new long[numSteps];
+    _correlatedSteps = numKeys == 1 ? new boolean[numSteps] : null;
+    _entryKeys = numKeys > 1 ? new int[INITIAL_CAPACITY][numKeys] : null;
+    _entrySteps = numKeys > 1 ? new boolean[INITIAL_CAPACITY][numSteps] : null;
   }
 
   public void add(int step, int correlationId) {
     if (correlationId != _lastCorrelationId) {
-      // End of correlation group, calculate funnel conversion counts
       incrStepCounters();
-
-      // initialize next correlation group
       for (int n = 0; n < _numSteps; n++) {
         _correlatedSteps[n] = false;
       }
@@ -50,7 +70,74 @@ class SortedAggregationResult {
     _correlatedSteps[step] = true;
   }
 
+  /**
+   * Multi-key add. Data must be sorted by correlationIds[0] (primary key).
+   * Secondary keys are tracked via linear scan over pre-allocated flat arrays.
+   *
+   * <p>The full correlationIds array (including the primary key at index 0) is used as the
+   * lookup key. The primary key is the same for every entry within a group, so including it
+   * is redundant but harmless — it avoids the cost of copying a sub-array.
+   */
+  public void addMultiKey(int step, int[] correlationIds) {
+    int primaryId = correlationIds[0];
+    if (primaryId != _lastPrimaryId) {
+      flushMultiKeyGroup();
+      _lastPrimaryId = primaryId;
+      _entryCount = 0;
+    }
+
+    for (int i = 0; i < _entryCount; i++) {
+      if (keysMatch(_entryKeys[i], correlationIds)) {
+        _entrySteps[i][step] = true;
+        return;
+      }
+    }
+
+    ensureCapacity();
+    System.arraycopy(correlationIds, 0, _entryKeys[_entryCount], 0, _numKeys);
+    Arrays.fill(_entrySteps[_entryCount], false);
+    _entrySteps[_entryCount][step] = true;
+    _entryCount++;
+  }
+
+  private boolean keysMatch(int[] stored, int[] incoming) {
+    for (int i = 0; i < _numKeys; i++) {
+      if (stored[i] != incoming[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  private void ensureCapacity() {
+    if (_entryCount < _entryKeys.length) {
+      return;
+    }
+    int oldCap = _entryKeys.length;
+    int newCap = oldCap * 2;
+    _entryKeys = Arrays.copyOf(_entryKeys, newCap);
+    _entrySteps = Arrays.copyOf(_entrySteps, newCap);
+    for (int i = oldCap; i < newCap; i++) {
+      _entryKeys[i] = new int[_numKeys];
+      _entrySteps[i] = new boolean[_numSteps];
+    }
+  }
+
+  private void flushMultiKeyGroup() {
+    for (int i = 0; i < _entryCount; i++) {
+      for (int n = 0; n < _numSteps; n++) {
+        if (!_entrySteps[i][n]) {
+          break;
+        }
+        _stepCounters[n]++;
+      }
+    }
+  }
+
   void incrStepCounters() {
+    if (_correlatedSteps == null) {
+      return;
+    }
     for (int n = 0; n < _numSteps; n++) {
       if (!_correlatedSteps[n]) {
         break;
@@ -59,9 +146,16 @@ class SortedAggregationResult {
     }
   }
 
+  /**
+   * Extracts the final funnel result. Must be called exactly once.
+   */
   public LongArrayList extractResult() {
-    // count last correlation id left open
-    incrStepCounters();
+    if (_numKeys > 1) {
+      flushMultiKeyGroup();
+      _entryCount = 0;
+    } else {
+      incrStepCounters();
+    }
     return LongArrayList.wrap(_stepCounters);
   }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationStrategy.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationStrategy.java
index 533d8723a74..7668e7a72bb 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationStrategy.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationStrategy.java
@@ -25,6 +25,8 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary;
 
 /**
  * Aggregation strategy for segments partitioned and sorted by the main 
correlation column.
+ * For multi-key correlate-by, data must be sorted by the first (primary) column; secondary
+ * keys are handled within each primary-key group by {@link SortedAggregationResult}.
  */
 class SortedAggregationStrategy extends 
AggregationStrategy<SortedAggregationResult> {
   public SortedAggregationStrategy(List<ExpressionContext> stepExpressions,
@@ -37,8 +39,18 @@ class SortedAggregationStrategy extends 
AggregationStrategy<SortedAggregationRes
     return new SortedAggregationResult(_numSteps);
   }
 
+  @Override
+  public SortedAggregationResult createAggregationResultMultiKey(Dictionary[] 
dictionaries) {
+    return new SortedAggregationResult(_numSteps, dictionaries.length);
+  }
+
   @Override
   void add(Dictionary dictionary, SortedAggregationResult aggResult, int step, 
int correlationId) {
     aggResult.add(step, correlationId);
   }
+
+  @Override
+  void addMultiKey(SortedAggregationResult aggResult, int step, Dictionary[] 
dictionaries, int[] correlationDictIds) {
+    aggResult.addMultiKey(step, correlationDictIds);
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/ThetaSketchAggregationStrategy.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/ThetaSketchAggregationStrategy.java
index a2ac25f8677..da16056705b 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/ThetaSketchAggregationStrategy.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/ThetaSketchAggregationStrategy.java
@@ -46,6 +46,15 @@ class ThetaSketchAggregationStrategy extends 
AggregationStrategy<UpdateSketch[]>
     return stepsSketches;
   }
 
+  @Override
+  public UpdateSketch[] createAggregationResultMultiKey(Dictionary[] 
dictionaries) {
+    final UpdateSketch[] stepsSketches = new UpdateSketch[_numSteps];
+    for (int n = 0; n < _numSteps; n++) {
+      stepsSketches[n] = _updateSketchBuilder.build();
+    }
+    return stepsSketches;
+  }
+
   @Override
   void add(Dictionary dictionary, UpdateSketch[] stepsSketches, int step, int 
correlationId) {
     final UpdateSketch sketch = stepsSketches[step];
@@ -66,8 +75,14 @@ class ThetaSketchAggregationStrategy extends 
AggregationStrategy<UpdateSketch[]>
         sketch.update(dictionary.getStringValue(correlationId));
         break;
       default:
-        throw new IllegalStateException("Illegal CORRELATED_BY column data 
type for FUNNEL_COUNT aggregation function: "
-            + dictionary.getValueType());
+        throw new IllegalStateException(
+            "Illegal CORRELATED_BY column data type for FUNNEL_COUNT 
aggregation function: "
+                + dictionary.getValueType());
     }
   }
+
+  @Override
+  void addMultiKey(UpdateSketch[] stepsSketches, int step, Dictionary[] 
dictionaries, int[] correlationDictIds) {
+    stepsSketches[step].update(DictIdsWrapper.toCompositeString(dictionaries, 
correlationDictIds));
+  }
 }
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapperTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapperTest.java
new file mode 100644
index 00000000000..aba3ab13fd7
--- /dev/null
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapperTest.java
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function.funnel;
+
+import java.util.Arrays;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+
+public class DictIdsWrapperTest {
+
+  // Two dicts of 100_000 each → product 10^10 > Integer.MAX_VALUE → HashMap path
+  private static final int LARGE_DICT_SIZE = 100_000;
+
+  private static Dictionary mockDict(int size) {
+    Dictionary d = mock(Dictionary.class);
+    when(d.length()).thenReturn(size);
+    return d;
+  }
+
+  private static Dictionary[] largeDicts(int count) {
+    Dictionary[] dicts = new Dictionary[count];
+    for (int i = 0; i < count; i++) {
+      dicts[i] = mockDict(LARGE_DICT_SIZE);
+    }
+    return dicts;
+  }
+
+  // ── Single-key constructor ───────────────────────────────────────────────
+
+  @Test
+  public void testSingleKeyNotMultiKey() {
+    DictIdsWrapper wrapper = new DictIdsWrapper(2, mockDict(100));
+    Assert.assertFalse(wrapper.isMultiKey());
+    Assert.assertFalse(wrapper.isHashMapPath());
+  }
+
+  // ── HashMap fallback path ────────────────────────────────────────────────
+
+  @Test
+  public void testHashMapPathSelectedWhenProductOverflows() {
+    DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+    Assert.assertTrue(wrapper.isHashMapPath(), "should select HashMap path for 
large key space");
+    Assert.assertTrue(wrapper.isMultiKey());
+  }
+
+  @Test
+  public void testHashMapPathNewKeyGetsSequentialId() {
+    DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+    Assert.assertEquals(wrapper.getCompositeCorrelationId(new int[]{0, 0}), 0);
+    Assert.assertEquals(wrapper.getCompositeCorrelationId(new int[]{0, 1}), 1);
+    Assert.assertEquals(wrapper.getCompositeCorrelationId(new int[]{1, 0}), 2);
+  }
+
+  @Test
+  public void testHashMapPathSameKeyReturnsSameId() {
+    DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+    int first = wrapper.getCompositeCorrelationId(new int[]{5, 7});
+    int second = wrapper.getCompositeCorrelationId(new int[]{5, 7});
+    Assert.assertEquals(first, second);
+  }
+
+  @Test
+  public void testHashMapPathKeyOrderSensitive() {
+    DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+    int id01 = wrapper.getCompositeCorrelationId(new int[]{0, 1});
+    int id10 = wrapper.getCompositeCorrelationId(new int[]{1, 0});
+    Assert.assertNotEquals(id01, id10, "[0,1] and [1,0] must map to different 
IDs");
+  }
+
+  @Test
+  public void testHashMapPathReverseRoundTrip() {
+    DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+    int[][] keys = {{0, 0}, {0, 1}, {1, 0}, {99999, 99999}, {42, 7}};
+    for (int[] key : keys) {
+      int id = wrapper.getCompositeCorrelationId(key);
+      int[] out = new int[2];
+      wrapper.reverseCompositeId(id, out);
+      Assert.assertEquals(out, key, "reverseCompositeId must round-trip for 
key " + Arrays.toString(key));
+    }
+  }
+
+  @Test
+  public void testHashMapPathThreeColumns() {
+    DictIdsWrapper wrapper = new DictIdsWrapper(3, largeDicts(3));
+    int id = wrapper.getCompositeCorrelationId(new int[]{1, 2, 3});
+    int[] out = new int[3];
+    wrapper.reverseCompositeId(id, out);
+    Assert.assertEquals(out, new int[]{1, 2, 3});
+  }
+
+  // ── Stride path reverseCompositeId ──────────────────────────────────────
+
+  @Test
+  public void testStridePathReverseRoundTrip() {
+    Dictionary[] dicts = {mockDict(10), mockDict(20), mockDict(5)};
+    DictIdsWrapper wrapper = new DictIdsWrapper(3, dicts);
+    Assert.assertFalse(wrapper.isHashMapPath(), "should select stride path for 
small key space");
+
+    int[][] keys = {{0, 0, 0}, {9, 19, 4}, {3, 7, 2}, {0, 1, 0}};
+    for (int[] key : keys) {
+      int id = wrapper.getCompositeCorrelationId(key);
+      int[] out = new int[3];
+      wrapper.reverseCompositeId(id, out);
+      Assert.assertEquals(out, key, "stride reverseCompositeId must round-trip 
for key " + Arrays.toString(key));
+    }
+  }
+}
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResultTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResultTest.java
new file mode 100644
index 00000000000..7e265376c6f
--- /dev/null
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResultTest.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function.funnel;
+
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+
+public class SortedAggregationResultTest {
+
+  @Test
+  public void testMultiKeyExtractResultDoesNotDoubleCount() {
+    // Two entities (primary key 0 and 1), each completing both steps.
+    // Expected: stepCounters = [2, 2] (one completion per entity per step).
+    SortedAggregationResult result = new SortedAggregationResult(2, 2);
+    result.addMultiKey(0, new int[]{0, 10});
+    result.addMultiKey(1, new int[]{0, 10});
+    result.addMultiKey(0, new int[]{1, 20});
+    result.addMultiKey(1, new int[]{1, 20});
+
+    LongArrayList counts = result.extractResult();
+    Assert.assertEquals(counts.getLong(0), 2L, "step 0 count");
+    Assert.assertEquals(counts.getLong(1), 2L, "step 1 count");
+  }
+
+  @Test
+  public void testMultiKeySecondaryKeysWithinPrimaryGroup() {
+    // Primary key 0 with two secondary keys: (0,10) and (0,20).
+    // (0,10) completes both steps; (0,20) completes only step 0.
+    // Expected: stepCounters = [2, 1].
+    SortedAggregationResult result = new SortedAggregationResult(2, 2);
+    result.addMultiKey(0, new int[]{0, 10});
+    result.addMultiKey(1, new int[]{0, 10});
+    result.addMultiKey(0, new int[]{0, 20});
+
+    LongArrayList counts = result.extractResult();
+    Assert.assertEquals(counts.getLong(0), 2L, "step 0 count");
+    Assert.assertEquals(counts.getLong(1), 1L, "step 1 count");
+  }
+}
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/FunnelCountTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/FunnelCountTest.java
index c18674fa9b2..ae13706243b 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/FunnelCountTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/FunnelCountTest.java
@@ -93,10 +93,17 @@ import static org.testng.Assert.assertNotNull;
  *
  * <h3>Expected funnel counts</h3>
  * <pre>
- * Overall:       [12, 10, 6, 3]
- * clothing:      [ 4,  4, 2, 2]
- * electronics:   [ 5,  4, 2, 1]
- * home:          [ 3,  2, 0, 0]
+ * Single-key CORRELATE_BY(user_id):
+ *   Overall:       [12, 10, 6, 3]
+ *   clothing:      [ 4,  4, 2, 2]
+ *   electronics:   [ 5,  4, 2, 1]
+ *   home:          [ 3,  2, 0, 0]
+ *
+ * Multi-key CORRELATE_BY(user_id, category):
+ *   Overall:       [12, 10, 4, 3]   (step 3 drops from 6→4: users 3,9 cross-category)
+ *   clothing:      [ 4,  4, 2, 2]   (same — grouping already separates by category)
+ *   electronics:   [ 5,  4, 2, 1]   (same)
+ *   home:          [ 3,  2, 0, 0]   (same)
  * </pre>
  */
 @Test(suiteName = "CustomClusterIntegrationTest")
@@ -118,6 +125,11 @@ public class FunnelCountTest extends 
CustomDataQueryClusterIntegrationTest {
   private static final long[] EXPECTED_CLOTHING = {4, 4, 2, 2};
   private static final long[] EXPECTED_HOME = {3, 2, 0, 0};
 
+  // Multi-key: CORRELATE_BY(user_id, category)
+  // Cross-category users 3 and 9 no longer complete checkout within a single (user, category) pair.
+  private static final long[] EXPECTED_MULTI_KEY_OVERALL = {12, 10, 4, 3};
+  private static final long[] EXPECTED_MULTI_KEY_FILTERED = {7, 6, 3, 3};
+
   @Override
   protected long getCountStarResult() {
     return COUNT_STAR;
@@ -231,6 +243,23 @@ public class FunnelCountTest extends 
CustomDataQueryClusterIntegrationTest {
         CATEGORY_COL, funnelCountAggregation(settings), TABLE_NAME, 
CATEGORY_COL, CATEGORY_COL);
   }
 
+  private String funnelCountMultiKeyAggregation(String settings) {
+    String settingsClause = (settings == null) ? "" : ", SETTINGS(" + settings 
+ ")";
+    return String.format("FUNNEL_COUNT("
+        + "STEPS(%1$s = '%2$s', %1$s = '%3$s', %1$s = '%4$s', %1$s = '%5$s'), "
+        + "CORRELATE_BY(%6$s, %7$s)"
+        + "%8$s)", ACTION_COL, VIEW, CART, CHECKOUT, PURCHASE, USER_ID_COL, 
CATEGORY_COL, settingsClause);
+  }
+
+  private String overallMultiKeyQuery(String settings) {
+    return String.format("SELECT %s FROM %s", 
funnelCountMultiKeyAggregation(settings), TABLE_NAME);
+  }
+
+  private String groupByMultiKeyQuery(String settings) {
+    return String.format("SELECT %s, %s FROM %s GROUP BY %s ORDER BY %s",
+        CATEGORY_COL, funnelCountMultiKeyAggregation(settings), TABLE_NAME, 
CATEGORY_COL, CATEGORY_COL);
+  }
+
   // ---------- assertion helpers ----------
 
   private JsonNode getRows(JsonNode response) {
@@ -394,4 +423,70 @@ public class FunnelCountTest extends 
CustomDataQueryClusterIntegrationTest {
     JsonNode rows = getRows(postQuery(emptyResultGroupByQuery(settings)));
     assertEquals(rows.size(), 0, "Expected zero groups when all rows are 
filtered");
   }
+
+  // ===================== Multi-key CORRELATE_BY tests =====================
+
+  @Test(dataProvider = "allStrategies")
+  public void testMultiKeyOverall(String settings)
+      throws Exception {
+    setUseMultiStageQueryEngine(false);
+    JsonNode rows = getRows(postQuery(overallMultiKeyQuery(settings)));
+    assertOverallResult(rows, EXPECTED_MULTI_KEY_OVERALL);
+  }
+
+  @Test(dataProvider = "allStrategies")
+  public void testMultiKeyGroupBy(String settings)
+      throws Exception {
+    setUseMultiStageQueryEngine(false);
+    JsonNode rows = getRows(postQuery(groupByMultiKeyQuery(settings)));
+    // Group-by category with CORRELATE_BY(user_id, category) produces the same results
+    // as single-key because grouping already separates rows by category.
+    assertGroupByResult(rows);
+  }
+
+  private String filteredMultiKeyQuery(String settings) {
+    return overallMultiKeyQuery(settings) + " WHERE " + USER_ID_COL + " <= 7";
+  }
+
+  @Test(dataProvider = "allStrategies")
+  public void testMultiKeyWithFilter(String settings)
+      throws Exception {
+    setUseMultiStageQueryEngine(false);
+    JsonNode rows = getRows(postQuery(filteredMultiKeyQuery(settings)));
+    assertEquals(rows.size(), 1);
+    assertStepCounts(rows.get(0).get(0), EXPECTED_MULTI_KEY_FILTERED);
+  }
+
+  // Multi-key: WHERE filter eliminates one segment entirely (users 7-12 only)
+  // user 9 crosses categories, so (user=9,home) only does view+cart, (user=9,electronics) only does checkout
+  // Expected: view=6, cart=5, checkout=2, purchase=1
+  private static final long[] EXPECTED_MULTI_KEY_ONE_SEGMENT = {6, 5, 2, 1};
+
+  @Test(dataProvider = "allStrategies")
+  public void testMultiKeyFilterEliminatesOneSegment(String settings)
+      throws Exception {
+    setUseMultiStageQueryEngine(false);
+    String query = overallMultiKeyQuery(settings) + " WHERE " + USER_ID_COL + 
" >= 7";
+    JsonNode rows = getRows(postQuery(query));
+    assertOverallResult(rows, EXPECTED_MULTI_KEY_ONE_SEGMENT);
+  }
+
+  @Test(dataProvider = "allStrategies")
+  public void testMultiKeyEmptyResultOverall(String settings)
+      throws Exception {
+    setUseMultiStageQueryEngine(false);
+    String query = overallMultiKeyQuery(settings) + " WHERE " + USER_ID_COL + 
" > 100";
+    JsonNode rows = getRows(postQuery(query));
+    assertOverallResult(rows, EXPECTED_ALL_FILTERED);
+  }
+
+  @Test(dataProvider = "allStrategies")
+  public void testMultiKeyEmptyResultGroupBy(String settings)
+      throws Exception {
+    setUseMultiStageQueryEngine(false);
+    String query = String.format("SELECT %s, %s FROM %s WHERE %s > 100 GROUP 
BY %s ORDER BY %s",
+        CATEGORY_COL, funnelCountMultiKeyAggregation(settings), TABLE_NAME, 
USER_ID_COL, CATEGORY_COL, CATEGORY_COL);
+    JsonNode rows = getRows(postQuery(query));
+    assertEquals(rows.size(), 0, "Expected zero groups when all rows are 
filtered");
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to