This is an automated email from the ASF dual-hosted git repository.

gianm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 06ef24c1e09 feat: ingest support for numeric typed 
ExpressionLambdaAggregatorFactory (#19508)
06ef24c1e09 is described below

commit 06ef24c1e09a8000c91dd0767e09278b34066252
Author: Clint Wylie <[email protected]>
AuthorDate: Sat May 23 14:40:06 2026 -0700

    feat: ingest support for numeric typed ExpressionLambdaAggregatorFactory 
(#19508)
---
 docs/querying/aggregations.md                      |   6 +-
 .../embedded/compact/CompactionTaskTest.java       |  90 ++++++
 .../ExpressionLambdaAggregatorFactory.java         |  82 ++++++
 .../ExpressionLambdaAggregationTest.java           | 215 +++++++++++++++
 .../ExpressionLambdaAggregatorFactoryTest.java     | 304 +++++++++++++++++++++
 5 files changed, 696 insertions(+), 1 deletion(-)

diff --git a/docs/querying/aggregations.md b/docs/querying/aggregations.md
index c7b7d4e4efc..3add7863c46 100644
--- a/docs/querying/aggregations.md
+++ b/docs/querying/aggregations.md
@@ -471,7 +471,11 @@ For these reasons, we have deprecated this aggregator and 
recommend using the Da
 
 ### Expression aggregator
 
-Aggregator applicable only at query time. Aggregates results using [Druid expressions](./math-expr.md) functions to facilitate building custom functions.
+Aggregates results using [Druid expressions](./math-expr.md) functions to facilitate building custom functions.
+
+The expression aggregator can be used at query time with any intermediate type. It can also be used at ingest time, but
+only when the type of `initialValue` is a primitive numeric type (`LONG` or `DOUBLE`) and matches the type of
+`initialCombineValue`. Other intermediate types, such as strings, arrays, and complex types, are query-time only.
 
 | Property | Description | Required |
 | --- | --- | --- |
diff --git 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/compact/CompactionTaskTest.java
 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/compact/CompactionTaskTest.java
index 84ee947c846..4692ec0715f 100644
--- 
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/compact/CompactionTaskTest.java
+++ 
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/compact/CompactionTaskTest.java
@@ -33,9 +33,12 @@ import 
org.apache.druid.java.util.common.granularity.Granularities;
 import org.apache.druid.java.util.common.granularity.Granularity;
 import org.apache.druid.java.util.common.jackson.JacksonUtils;
 import org.apache.druid.query.Druids;
+import org.apache.druid.query.aggregation.CountAggregatorFactory;
+import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
 import org.apache.druid.query.aggregation.datasketches.hll.HllSketchModule;
 import 
org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchModule;
 import org.apache.druid.query.aggregation.datasketches.theta.SketchModule;
+import org.apache.druid.query.expression.TestExprMacroTable;
 import org.apache.druid.query.metadata.metadata.SegmentMetadataQuery;
 import org.apache.druid.segment.TestHelper;
 import org.apache.druid.testing.embedded.EmbeddedClusterApis;
@@ -55,6 +58,7 @@ import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
@@ -107,6 +111,65 @@ public class CompactionTaskTest extends CompactionTestBase
           "namespace", "continent", "country", "region", "city", "timestamp"
       );
 
+  /**
+   * Supplies an index task with the same shape as {@link MoreResources.Task#INDEX_TASK_WITH_AGGREGATORS}, except
+   * that its metrics are a row count plus two {@link ExpressionLambdaAggregatorFactory} aggregators over the
+   * {@code added} long field: {@code added_sum_expr} (addition) and {@code added_or_expr} (bitwise OR). Consumed by
+   * {@link #testCompactionWithExpressionLambdaAggregator}.
+   */
+  private static final Supplier<TaskBuilder.Index> INDEX_TASK_WITH_EXPR_AGG = () -> {
+    final ExpressionLambdaAggregatorFactory addedSum = new ExpressionLambdaAggregatorFactory(
+        "added_sum_expr",
+        Set.of("added"),
+        null,
+        "0",
+        null,
+        null,
+        false,
+        false,
+        "__acc + added",
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+    final ExpressionLambdaAggregatorFactory addedOr = new ExpressionLambdaAggregatorFactory(
+        "added_or_expr",
+        Set.of("added"),
+        null,
+        "0",
+        null,
+        null,
+        false,
+        false,
+        "bitwiseOr(\"__acc\", \"added\")",
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    return TaskBuilder
+        .ofTypeIndex()
+        .jsonInputFormat()
+        .localInputSourceWithFiles(
+            Resources.DataFile.tinyWiki1Json(),
+            Resources.DataFile.tinyWiki2Json(),
+            Resources.DataFile.tinyWiki3Json()
+        )
+        .timestampColumn("timestamp")
+        .dimensions(
+            "page",
+            "language", "tags", "user", "unpatrolled", "newPage", "robot",
+            "anonymous", "namespace", "continent", "country", "region", "city"
+        )
+        .metricAggregates(
+            new CountAggregatorFactory("ingested_events"),
+            addedSum,
+            addedOr
+        )
+        .dynamicPartitionWithMaxRows(3)
+        .granularitySpec("DAY", "SECOND", true)
+        .appendToExisting(false);
+  };
+
   private String fullDatasourceName;
 
   @BeforeEach
@@ -259,6 +322,33 @@ public class CompactionTaskTest extends CompactionTestBase
     loadDataAndCompact(INDEX_TASK_WITH_TIMESTAMP.get(), COMPACTION_TASK.get(), 
null);
   }
 
+  @Test
+  public void testCompactionWithExpressionLambdaAggregator() throws Exception
+  {
+    // The same totals query is issued on both sides of the compaction.
+    final String metricTotalsSql = "SELECT SUM(added_sum_expr), SUM(added_or_expr) FROM %s";
+
+    try (final Closeable ignored = unloader(fullDatasourceName)) {
+      runTask(INDEX_TASK_WITH_EXPR_AGG.get());
+      verifySegmentsCount(4);
+
+      // Baseline: metric totals as ingested, before any compaction.
+      final String totalsBeforeCompaction = cluster.runSql(metricTotalsSql, fullDatasourceName);
+
+      // Going from 4 segments to 2 performs cross-segment rollup, which drives RowCombiningTimeAndDimsIterator
+      // into ExpressionLambdaAggregatorFactory.makeAggregateCombiner().
+      compactData(COMPACTION_TASK.get(), null, null);
+      verifySegmentsCount(2);
+
+      // Compaction must not change the metric totals.
+      final String totalsAfterCompaction = cluster.runSql(metricTotalsSql, fullDatasourceName);
+      Assertions.assertEquals(totalsBeforeCompaction, totalsAfterCompaction);
+    }
+  }
+
   private void loadDataAndCompact(
       TaskBuilder.Index indexTask,
       TaskBuilder.Compact compactionResource,
diff --git 
a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
 
b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
index 3235d709eee..c901b52962f 100644
--- 
a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
+++ 
b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
@@ -40,9 +40,11 @@ import org.apache.druid.math.expr.SettableObjectBinding;
 import org.apache.druid.query.cache.CacheKeyBuilder;
 import org.apache.druid.segment.ColumnInspector;
 import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
 import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
 import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.ValueType;
 import org.apache.druid.segment.virtual.ExpressionPlan;
 import org.apache.druid.segment.virtual.ExpressionPlanner;
 import org.apache.druid.segment.virtual.ExpressionSelectors;
@@ -347,6 +349,86 @@ public class ExpressionLambdaAggregatorFactory extends 
AggregatorFactory
     ).value();
   }
 
+  /**
+   * Returns a combiner that merges stored intermediate values of this aggregator by delegating to {@code combine()}.
+   * Only {@code LONG} and {@code DOUBLE} intermediate types that match the type of {@code initialCombineValue} are
+   * handled here; every other case defers to the superclass implementation.
+   */
+  @Override
+  public AggregateCombiner makeAggregateCombiner()
+  {
+    final ColumnType intermediateType = getIntermediateType();
+    // The combiner delegates to combine(), which feeds inputs into combineExpression typed against
+    // initialCombineValue. If the fold-side intermediate type (what's stored in the segment column) differs from the
+    // combine-side type, the primitive selector would silently feed wrong-typed values into the expression.
+    // Fall through to UOE.
+    if (!intermediateType.equals(ExpressionType.toColumnType(initialCombineValue.get().type()))) {
+      return super.makeAggregateCombiner();
+    }
+    if (intermediateType.is(ValueType.LONG)) {
+      return new LongAggregateCombiner()
+      {
+        private long state;
+        private boolean isNull;
+
+        @Override
+        public void reset(ColumnValueSelector selector)
+        {
+          // Check nullness before touching the primitive getter: its value is not meaningful for a null row.
+          // A null row stores 0, the same convention fold() uses for a null result.
+          isNull = selector.isNull();
+          state = isNull ? 0L : selector.getLong();
+        }
+
+        @Override
+        public void fold(ColumnValueSelector selector)
+        {
+          final Object combined = combine(isNull ? null : state, selector.getObject());
+          isNull = combined == null;
+          state = combined == null ? 0L : ((Number) combined).longValue();
+        }
+
+        @Override
+        public long getLong()
+        {
+          return state;
+        }
+
+        @Override
+        public boolean isNull()
+        {
+          return isNull;
+        }
+      };
+    } else if (intermediateType.is(ValueType.DOUBLE)) {
+      return new DoubleAggregateCombiner()
+      {
+        private double state;
+        private boolean isNull;
+
+        @Override
+        public void reset(ColumnValueSelector selector)
+        {
+          // Check nullness before touching the primitive getter: its value is not meaningful for a null row.
+          // A null row stores 0.0, the same convention fold() uses for a null result.
+          isNull = selector.isNull();
+          state = isNull ? 0.0 : selector.getDouble();
+        }
+
+        @Override
+        public void fold(ColumnValueSelector selector)
+        {
+          final Object combined = combine(isNull ? null : state, selector.getObject());
+          isNull = combined == null;
+          state = combined == null ? 0.0 : ((Number) combined).doubleValue();
+        }
+
+        @Override
+        public double getDouble()
+        {
+          return state;
+        }
+
+        @Override
+        public boolean isNull()
+        {
+          return isNull;
+        }
+      };
+    }
+    return super.makeAggregateCombiner();
+  }
+
   @Override
   public Object deserialize(Object object)
   {
diff --git 
a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregationTest.java
 
b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregationTest.java
new file mode 100644
index 00000000000..baef2de3a0a
--- /dev/null
+++ 
b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregationTest.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import org.apache.druid.data.input.InputRow;
+import org.apache.druid.data.input.MapBasedInputRow;
+import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.data.input.impl.StringDimensionSchema;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.query.Druids;
+import org.apache.druid.query.Result;
+import org.apache.druid.query.expression.TestExprMacroTable;
+import org.apache.druid.query.timeseries.TimeseriesQuery;
+import org.apache.druid.query.timeseries.TimeseriesResultValue;
+import org.apache.druid.segment.IndexBuilder;
+import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.QueryableIndexSegment;
+import org.apache.druid.segment.Segment;
+import org.apache.druid.segment.incremental.IncrementalIndexSchema;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.apache.druid.timeline.SegmentId;
+import org.apache.druid.utils.CloseableUtils;
+import org.joda.time.DateTime;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Checks that {@link ExpressionLambdaAggregatorFactory} works as an ingest-time metric for primitive numeric types:
+ * rows are rolled up into a merged index and the stored metrics are read back through the combining factories.
+ */
+public class ExpressionLambdaAggregationTest extends InitializedNullHandlingTest
+{
+  private static final String DIM = "groupKey";
+  private static final String LONG_FIELD = "longField";
+  private static final String DOUBLE_FIELD = "doubleField";
+  private static final DateTime TIMESTAMP = DateTimes.of("2020-01-01");
+
+  @Rule
+  public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+  private QueryableIndex mergedIndex;
+  private Segment segment;
+
+  /**
+   * Closes whatever the test managed to open: the segment first, then the index it wraps.
+   */
+  @After
+  public void tearDown()
+  {
+    if (segment != null) {
+      CloseableUtils.closeAndWrapExceptions(segment);
+    }
+    if (mergedIndex != null) {
+      CloseableUtils.closeAndWrapExceptions(mergedIndex);
+    }
+  }
+
+  @Test
+  public void testNumericExpressionLambdaIngestRollupViaMerge() throws Exception
+  {
+    final ExpressionLambdaAggregatorFactory longSum = new ExpressionLambdaAggregatorFactory(
+        "long_sum",
+        Set.of(LONG_FIELD),
+        null,
+        "0",
+        null,
+        null,
+        false,
+        false,
+        "__acc + " + LONG_FIELD,
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    // Same shape as BitwiseSqlAggregator: one input field folded with op("__acc", field).
+    final ExpressionLambdaAggregatorFactory bitwiseOr = new ExpressionLambdaAggregatorFactory(
+        "bitwise_or",
+        ImmutableSet.of(LONG_FIELD),
+        null,
+        "0",
+        null,
+        null,
+        false,
+        false,
+        "bitwiseOr(\"__acc\", \"" + LONG_FIELD + "\")",
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    final ExpressionLambdaAggregatorFactory doubleSum = new ExpressionLambdaAggregatorFactory(
+        "double_sum",
+        ImmutableSet.of(DOUBLE_FIELD),
+        null,
+        "0.0",
+        null,
+        null,
+        false,
+        false,
+        "__acc + " + DOUBLE_FIELD,
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    // All three rows share (timestamp, dim), so they collapse into a single output row during merge.
+    // longField:   1 (0b001), 2 (0b010), 4 (0b100) -> sum=7, bitwiseOr=7
+    // doubleField: 1.5, 2.0, 0.25                  -> sum=3.75
+    final List<InputRow> inputRows = List.of(
+        row(1L, 1.5),
+        row(2L, 2.0),
+        row(4L, 0.25)
+    );
+
+    final DimensionsSpec dimensionsSpec =
+        DimensionsSpec.builder()
+                      .setDimensions(ImmutableList.of(new StringDimensionSchema(DIM)))
+                      .build();
+
+    final IncrementalIndexSchema rollupSchema = IncrementalIndexSchema.builder()
+        .withQueryGranularity(Granularities.NONE)
+        .withRollup(true)
+        .withDimensionsSpec(dimensionsSpec)
+        .withMetrics(
+            new CountAggregatorFactory("count"),
+            longSum,
+            bitwiseOr,
+            doubleSum
+        )
+        .build();
+
+    mergedIndex = IndexBuilder.create()
+                              .tmpDir(tempFolder.newFolder())
+                              .schema(rollupSchema)
+                              .intermediaryPersistSize(1)
+                              .rows(inputRows)
+                              .buildMMappedMergedIndex();
+
+    segment = new QueryableIndexSegment(mergedIndex, SegmentId.dummy("test"));
+
+    // Read the stored metrics back through the combining form of each ingest-time aggregator.
+    final TimeseriesQuery combiningQuery = Druids.newTimeseriesQueryBuilder()
+                                                 .dataSource("test")
+                                                 .granularity(Granularities.ALL)
+                                                 .intervals("1970/2050")
+                                                 .aggregators(
+                                                     new LongSumAggregatorFactory("count", "count"),
+                                                     longSum.getCombiningFactory(),
+                                                     bitwiseOr.getCombiningFactory(),
+                                                     doubleSum.getCombiningFactory()
+                                                 )
+                                                 .build();
+
+    try (final AggregationTestHelper helper =
+             AggregationTestHelper.createTimeseriesQueryAggregationTestHelper(Collections.emptyList(), tempFolder)) {
+
+      final Sequence<Result<TimeseriesResultValue>> results = helper.runQueryOnSegmentsObjs(
+          ImmutableList.of(segment),
+          combiningQuery
+      );
+      final Result<TimeseriesResultValue> onlyResult = Iterables.getOnlyElement(results.toList());
+      final TimeseriesResultValue totals = onlyResult.getValue();
+
+      // Three input rows rolled up into one, count reflects rollup happened
+      Assert.assertEquals(3L, totals.getLongMetric("count").longValue());
+      Assert.assertEquals(7L, totals.getLongMetric("long_sum").longValue());
+      Assert.assertEquals(7L, totals.getLongMetric("bitwise_or").longValue());
+      Assert.assertEquals(3.75, totals.getDoubleMetric("double_sum").doubleValue(), 0.0);
+    }
+  }
+
+  /**
+   * Builds an input row at the shared timestamp and dimension value "a" carrying the given metric inputs.
+   */
+  private static InputRow row(long longVal, double doubleVal)
+  {
+    final ImmutableMap<String, Object> event = ImmutableMap.of(
+        DIM, "a",
+        LONG_FIELD, longVal,
+        DOUBLE_FIELD, doubleVal
+    );
+    return new MapBasedInputRow(TIMESTAMP, ImmutableList.of(DIM), event);
+  }
+}
diff --git 
a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
 
b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
index 499bcef08fe..29bf850d3d4 100644
--- 
a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
+++ 
b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
@@ -24,24 +24,31 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import nl.jqno.equalsverifier.EqualsVerifier;
 import org.apache.druid.java.util.common.HumanReadableBytes;
+import org.apache.druid.java.util.common.UOE;
 import org.apache.druid.java.util.common.granularity.Granularities;
 import org.apache.druid.query.Druids;
 import 
org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
 import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
 import 
org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
 import org.apache.druid.query.expression.TestExprMacroTable;
+import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
 import org.apache.druid.query.timeseries.TimeseriesQuery;
 import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
+import org.apache.druid.segment.ColumnValueSelector;
 import org.apache.druid.segment.TestHelper;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.selector.TestColumnValueSelector;
 import org.apache.druid.testing.InitializedNullHandlingTest;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 
+import javax.annotation.Nullable;
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
 
 public class ExpressionLambdaAggregatorFactoryTest extends 
InitializedNullHandlingTest
 {
@@ -545,6 +552,303 @@ public class ExpressionLambdaAggregatorFactoryTest 
extends InitializedNullHandli
     Assert.assertEquals(ColumnType.DOUBLE, agg.getResultType());
   }
 
+  @Test
+  public void testLongAggregateCombiner()
+  {
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name",
+        ImmutableSet.of("x"),
+        null,
+        "0",
+        null,
+        true,
+        false,
+        false,
+        "__acc + x",
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    final AggregateCombiner combiner = factory.makeAggregateCombiner();
+    final TestColumnValueSelector<Long> selector = TestColumnValueSelector.of(
+        Long.class,
+        Arrays.asList(1L, 2L, 3L)
+    );
+
+    // The first value seeds the combiner.
+    selector.advance();
+    combiner.reset(selector);
+    Assert.assertEquals(1L, combiner.getLong());
+
+    // Each later value is folded into the running total: 1 + 2 = 3, then 3 + 3 = 6.
+    final long[] expectedRunningTotals = {3L, 6L};
+    for (long expectedTotal : expectedRunningTotals) {
+      selector.advance();
+      combiner.fold(selector);
+      Assert.assertEquals(expectedTotal, combiner.getLong());
+    }
+  }
+
+  @Test
+  public void testDoubleAggregateCombiner()
+  {
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name",
+        ImmutableSet.of("x"),
+        null,
+        "0.0",
+        null,
+        true,
+        false,
+        false,
+        "__acc + x",
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    final AggregateCombiner combiner = factory.makeAggregateCombiner();
+    final TestColumnValueSelector<Double> selector = TestColumnValueSelector.of(
+        Double.class,
+        Arrays.asList(1.5, 2.25, 0.25)
+    );
+
+    // The first value seeds the combiner.
+    selector.advance();
+    combiner.reset(selector);
+    Assert.assertEquals(1.5, combiner.getDouble(), 0.0);
+
+    // Each later value is folded into the running total: 1.5 + 2.25 = 3.75, then 3.75 + 0.25 = 4.0.
+    final double[] expectedRunningTotals = {3.75, 4.0};
+    for (double expectedTotal : expectedRunningTotals) {
+      selector.advance();
+      combiner.fold(selector);
+      Assert.assertEquals(expectedTotal, combiner.getDouble(), 0.0);
+    }
+  }
+
+  @Test
+  public void testNullableAggregateCombinerSkipsNulls()
+  {
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name",
+        ImmutableSet.of("x"),
+        null,
+        "0",
+        null,
+        true,
+        false,
+        false,
+        "__acc + x",
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    final AggregateCombiner nullableCombiner = factory.makeNullableAggregateCombiner();
+    final NullableLongSelector input = new NullableLongSelector(Arrays.asList(null, 5L, null, 7L));
+
+    // Row 0 is null: the combiner starts out null.
+    input.advance();
+    nullableCombiner.reset(input);
+    Assert.assertTrue(nullableCombiner.isNull());
+
+    // Row 1 is 5: the first non-null value becomes the state.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertFalse(nullableCombiner.isNull());
+    Assert.assertEquals(5L, nullableCombiner.getLong());
+
+    // Row 2 is null: the state is left untouched.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertEquals(5L, nullableCombiner.getLong());
+
+    // Row 3 is 7: 5 + 7 = 12.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertEquals(12L, nullableCombiner.getLong());
+  }
+
+  @Test
+  public void testNullableAggregateCombinerWhenCombineAggregatesNullsExpressionSeesNulls()
+  {
+    // With shouldCombineAggregateNullInputs=true the combine expression receives null inputs directly and must deal
+    // with them itself; this one uses `nvl` to coalesce nulls to 0 so the accumulator keeps advancing.
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name",
+        ImmutableSet.of("x"),
+        null,
+        "0",
+        null,
+        true,
+        true,
+        true,
+        "nvl(__acc, 0) + nvl(x, 0)",
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    final AggregateCombiner nullableCombiner = factory.makeNullableAggregateCombiner();
+    final NullableLongSelector input = new NullableLongSelector(Arrays.asList(1L, null, 3L));
+
+    // Row 0 is 1: seeds the state.
+    input.advance();
+    nullableCombiner.reset(input);
+    Assert.assertEquals(1L, nullableCombiner.getLong());
+
+    // Row 1 is null: it reaches the expression, which coalesces it to 0, so the total stays at 1.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertEquals(1L, nullableCombiner.getLong());
+
+    // Row 2 is 3: 1 + 3 = 4.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertEquals(4L, nullableCombiner.getLong());
+  }
+
+  @Test
+  public void testNullableAggregateCombinerNullExpressionResultPropagates()
+  {
+    // shouldCombineAggregateNullInputs=true combined with an expression that does not guard against nulls:
+    // `__acc + null` evaluates to null in Druid expression semantics, and the combiner must then report isNull.
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name",
+        ImmutableSet.of("x"),
+        null,
+        "0",
+        null,
+        true,
+        true,
+        true,
+        "__acc + x",
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+
+    final AggregateCombiner nullableCombiner = factory.makeNullableAggregateCombiner();
+    final NullableLongSelector input = new NullableLongSelector(Arrays.asList(1L, null));
+
+    // Row 0 is 1: a non-null starting state.
+    input.advance();
+    nullableCombiner.reset(input);
+    Assert.assertFalse(nullableCombiner.isNull());
+    Assert.assertEquals(1L, nullableCombiner.getLong());
+
+    // Row 1 is null: the null expression result turns the combiner null.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertTrue(nullableCombiner.isNull());
+  }
+
+
+  /**
+   * List-backed selector for tests. Entries may be null; a null entry reports {@code isNull() == true} and reads as
+   * zero through the primitive getters. Starts positioned before the first entry, so {@link #advance()} must be
+   * called before the first read.
+   */
+  private static final class NullableLongSelector implements ColumnValueSelector<Long>
+  {
+    private final List<Long> rows;
+    private int position = -1;
+
+    NullableLongSelector(List<Long> values)
+    {
+      this.rows = values;
+    }
+
+    /**
+     * Moves to the next entry.
+     */
+    void advance()
+    {
+      position++;
+    }
+
+    @Override
+    public long getLong()
+    {
+      final Long current = getObject();
+      if (current == null) {
+        return 0L;
+      }
+      return current;
+    }
+
+    @Override
+    public double getDouble()
+    {
+      return getLong();
+    }
+
+    @Override
+    public float getFloat()
+    {
+      return getLong();
+    }
+
+    @Override
+    public boolean isNull()
+    {
+      return getObject() == null;
+    }
+
+    @Nullable
+    @Override
+    public Long getObject()
+    {
+      return rows.get(position);
+    }
+
+    @Override
+    public Class<Long> classOfObject()
+    {
+      return Long.class;
+    }
+
+    @Override
+    public void inspectRuntimeShape(RuntimeShapeInspector inspector)
+    {
+      // Nothing to report for this test-only selector.
+    }
+  }
+
+  @Test(expected = UOE.class)
+  public void testAggregateCombinerNotSupportedForNonNumericTypes()
+  {
+    // Both the fold seed and the combine seed are empty strings, so the intermediate type is not a primitive
+    // numeric type and no combiner is available.
+    final String emptyStringSeed = "''";
+    final ExpressionLambdaAggregatorFactory stringConcatAgg = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name",
+        ImmutableSet.of("x"),
+        null,
+        emptyStringSeed,
+        emptyStringSeed,
+        true,
+        true,
+        true,
+        "concat(__acc, x)",
+        "concat(__acc, expr_agg_name)",
+        null,
+        null,
+        new HumanReadableBytes(2048),
+        TestExprMacroTable.INSTANCE
+    );
+
+    stringConcatAgg.makeAggregateCombiner();
+  }
+
+  @Test
+  public void testAggregateCombinerNotSupportedWhenFoldAndCombineTypesDiffer()
+  {
+    // fold seed is LONG (intermediate column type), but combine seed is LONG_ARRAY — combining a long segment column
+    // with an expression that expects arrays would silently produce wrong values, so the combiner refuses to handle
+    // it.
+    ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name",
+        ImmutableSet.of("x"),
+        null,
+        "0",
+        "ARRAY<LONG>[]",
+        null,
+        false,
+        false,
+        "__acc + x",
+        "array_set_add(__acc, expr_agg_name)",
+        null,
+        null,
+        new HumanReadableBytes(2048),
+        TestExprMacroTable.INSTANCE
+    );
+
+    Assert.assertEquals(ColumnType.LONG, agg.getIntermediateType());
+    // Scope the expected exception to the call under test. A method-wide `expected = UOE.class` would also pass if
+    // the constructor or getIntermediateType() threw, without makeAggregateCombiner() ever being reached.
+    Assert.assertThrows(UOE.class, agg::makeAggregateCombiner);
+  }
+
   @Test
   public void testResultArraySignature()
   {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to