This is an automated email from the ASF dual-hosted git repository.
soumyava pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new 93cd6386458 Enabling aggregateMultipleValues in all StringAnyAggregators (#15434)
93cd6386458 is described below
commit 93cd63864581b16d6b9407723541616a7465da68
Author: Pranav <[email protected]>
AuthorDate: Wed Nov 29 14:32:49 2023 -0800
Enabling aggregateMultipleValues in all StringAnyAggregators (#15434)
* Enabling aggregateMultipleValues in all StringAnyAggregators
* Adding more tests
* More validation
* fix warning
* updating asserts in decoupled mode
* fix intellij inspection
* Addressing comments
* Addressing comments
* Adding early validations and make aggregate consistent across all
* fixing tests
* fixing tests
* Update docs/querying/sql-aggregations.md
Co-authored-by: Clint Wylie <[email protected]>
* fixing static check
---------
Co-authored-by: Clint Wylie <[email protected]>
---
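For readers skimming the diff below, the value-selection rule that the new `readValue` helpers apply in `StringAnyAggregator` and `StringAnyBufferAggregator` can be sketched in isolation. This is an illustrative stand-alone sketch, not the committed code: the class `AnyValueSketch` and its helpers are hypothetical stand-ins, and it assumes that `DimensionHandlerUtils.convertObjectToString` falls back to `toString()` for lists and that `StringUtils.fastLooseChop` truncates by character count.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-alone sketch of the selection rule; not part of Druid.
class AnyValueSketch
{
  // Stand-in for DimensionHandlerUtils.convertObjectToString (assumed toString() fallback).
  static String convertObjectToString(Object o)
  {
    return o == null ? null : o.toString();
  }

  // Stand-in for StringUtils.fastLooseChop (assumed to cut at maxChars characters).
  static String chop(String s, int maxChars)
  {
    return (s == null || s.length() <= maxChars) ? s : s.substring(0, maxChars);
  }

  // null and empty lists yield null; single-element lists yield that element;
  // longer lists yield the whole list as a string only when the flag is set,
  // otherwise the first element.
  static String readValue(Object object, boolean aggregateMultipleValues)
  {
    if (object == null) {
      return null;
    }
    if (object instanceof List) {
      List<?> list = (List<?>) object;
      if (list.isEmpty()) {
        return null;
      }
      if (list.size() == 1 || !aggregateMultipleValues) {
        return convertObjectToString(list.get(0));
      }
      return convertObjectToString(list);
    }
    return convertObjectToString(object);
  }

  public static void main(String[] args)
  {
    List<String> mvd = Arrays.asList("AAAA", "AAAAB", "AAAC");
    System.out.println(readValue(mvd, true));           // whole list as one string
    System.out.println(readValue(mvd, false));          // first element only
    System.out.println(chop(readValue(mvd, true), 10)); // list string cut to 10 characters
  }
}
```

With the same multi-value row and a limit of 10, the sketch reproduces the values asserted in `StringAnyAggregatorFactoryTest` below: `"[AAAA, AAA"` with the flag on and `"AAAA"` with it off.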
docs/querying/sql-aggregations.md | 2 +-
docs/querying/sql-functions.md | 2 +-
.../query/aggregation/any/StringAnyAggregator.java | 34 +++++-
.../any/StringAnyAggregatorFactory.java | 26 ++--
.../aggregation/any/StringAnyBufferAggregator.java | 29 ++++-
.../aggregation/any/StringAnyVectorAggregator.java | 33 ++++-
.../query/aggregation/AggregatorFactoryTest.java | 5 +-
.../aggregation/any/StringAnyAggregationTest.java | 2 +-
.../any/StringAnyAggregatorFactoryTest.java | 136 +++++++++++++++------
.../any/StringAnyBufferAggregatorTest.java | 124 ++++++++++++++-----
.../any/StringAnyVectorAggregatorTest.java | 45 +++++--
.../segment/virtual/FallbackVirtualColumnTest.java | 2 +-
.../builtin/EarliestLatestAnySqlAggregator.java | 61 +++++----
.../builtin/EarliestLatestBySqlAggregator.java | 6 +-
.../expression/DefaultOperandTypeChecker.java | 15 ++-
.../calcite/expression/OperatorConversions.java | 2 +-
.../druid/sql/calcite/CalciteJoinQueryTest.java | 73 ++++++++++-
.../apache/druid/sql/calcite/CalciteQueryTest.java | 34 ++++--
.../apache/druid/sql/calcite/QueryTestBuilder.java | 6 +
website/.spelling | 1 +
20 files changed, 493 insertions(+), 145 deletions(-)
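The vectorized path in `StringAnyVectorAggregator` follows the same rule but works on dictionary ids rather than objects. A minimal sketch under stated assumptions: the row is modelled as a plain `int[]` instead of `IndexedInts`, the selector's `lookupName` is modelled as an `IntFunction<String>`, and the list is rendered with `List.toString()` in place of `DimensionHandlerUtils.convertObjectToString`. The class name `VectorReadValueSketch` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntFunction;

// Hypothetical stand-alone sketch of the vectorized selection rule; not part of Druid.
class VectorReadValueSketch
{
  static String readValue(int[] row, IntFunction<String> lookupName, boolean aggregateMultipleValues)
  {
    if (row.length == 0) {
      return null; // an empty multi-value row aggregates to null
    }
    if (!aggregateMultipleValues || row.length == 1) {
      return lookupName.apply(row[0]); // first (or only) dictionary entry
    }
    List<String> values = new ArrayList<>();
    for (int id : row) {
      values.add(lookupName.apply(id)); // resolve every id in row order
    }
    return values.toString();
  }

  public static void main(String[] args)
  {
    String[] dictionary = {"a", "b", "c"};
    IntFunction<String> lookup = id -> dictionary[id];
    System.out.println(readValue(new int[]{2, 0}, lookup, true));
    System.out.println(readValue(new int[]{2, 0}, lookup, false));
  }
}
```

Resolving ids only when the flag is set and the row has more than one entry keeps the single-value case as cheap as it was before the change.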
diff --git a/docs/querying/sql-aggregations.md b/docs/querying/sql-aggregations.md
index d005ce1fd11..b2df640a68f 100644
--- a/docs/querying/sql-aggregations.md
+++ b/docs/querying/sql-aggregations.md
@@ -90,7 +90,7 @@ In the aggregation functions supported by Druid, only `COUNT`, `ARRAY_AGG`, and
|`EARLIEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the earliest value of `expr`.<br />The earliest value of `expr` is taken from the row with the overall earliest non-null value of `timestampExpr`. <br />If the earliest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit ar [...]
|`LATEST(expr, [maxBytesPerValue])`|Returns the latest value of `expr`<br />The `expr` must come from a relation with a timestamp column (like `__time` in a Druid datasource) and the "latest" is taken from the row with the overall latest non-null value of the timestamp column.<br />If the latest non-null value of the timestamp column appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of spac [...]
|`LATEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the latest value of `expr`.<br />The latest value of `expr` is taken from the row with the overall latest non-null value of `timestampExpr`.<br />If the overall latest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are t [...]
-|`ANY_VALUE(expr, [maxBytesPerValue])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`).<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBy [...]
+|`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`).<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to w [...]
|`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.|N/A|
|`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`|
|`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results will be based on the default for the element type.|`null`|
diff --git a/docs/querying/sql-functions.md b/docs/querying/sql-functions.md
index 8e43076518d..47b8ca90434 100644
--- a/docs/querying/sql-functions.md
+++ b/docs/querying/sql-functions.md
@@ -50,7 +50,7 @@ Calculates the arc cosine of a numeric expression.
## ANY_VALUE
-`ANY_VALUE(expr, [maxBytesPerValue])`
+`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`
**Function type:** [Aggregation](sql-aggregations.md)
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java
index 352b65e5646..aae267364c4 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java
@@ -24,19 +24,23 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
import org.apache.druid.segment.DimensionHandlerUtils;
+import java.util.List;
+
public class StringAnyAggregator implements Aggregator
{
private final BaseObjectColumnValueSelector valueSelector;
private final int maxStringBytes;
private boolean isFound;
private String foundValue;
+ private final boolean aggregateMultipleValues;
- public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes)
+ public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues)
{
this.valueSelector = valueSelector;
this.maxStringBytes = maxStringBytes;
this.foundValue = null;
this.isFound = false;
+ this.aggregateMultipleValues = aggregateMultipleValues;
}
@Override
@@ -44,18 +48,36 @@ public class StringAnyAggregator implements Aggregator
{
if (!isFound) {
final Object object = valueSelector.getObject();
- foundValue = DimensionHandlerUtils.convertObjectToString(object);
- if (foundValue != null && foundValue.length() > maxStringBytes) {
- foundValue = foundValue.substring(0, maxStringBytes);
- }
+ foundValue = StringUtils.fastLooseChop(readValue(object), maxStringBytes);
isFound = true;
}
}
+ private String readValue(final Object object)
+ {
+ if (object == null) {
+ return null;
+ }
+ if (object instanceof List) {
+ List<Object> objectList = (List) object;
+ if (objectList.size() == 0) {
+ return null;
+ }
+ if (objectList.size() == 1) {
+ return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
+ }
+ if (aggregateMultipleValues) {
+ return DimensionHandlerUtils.convertObjectToString(objectList);
+ }
+ return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
+ }
+ return DimensionHandlerUtils.convertObjectToString(object);
+ }
+
@Override
public Object get()
{
- return StringUtils.chop(foundValue, maxStringBytes);
+ return foundValue;
}
@Override
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java
index 307de0650c3..67682bfde9d 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java
@@ -48,13 +48,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
private final String fieldName;
private final String name;
- protected final int maxStringBytes;
+ private final int maxStringBytes;
+ private final boolean aggregateMultipleValues;
@JsonCreator
public StringAnyAggregatorFactory(
@JsonProperty("name") String name,
@JsonProperty("fieldName") final String fieldName,
- @JsonProperty("maxStringBytes") Integer maxStringBytes
+ @JsonProperty("maxStringBytes") Integer maxStringBytes,
+ @JsonProperty("aggregateMultipleValues") @Nullable final Boolean aggregateMultipleValues
)
{
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
@@ -67,18 +69,19 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
this.maxStringBytes = maxStringBytes == null
? StringFirstAggregatorFactory.DEFAULT_MAX_STRING_SIZE
: maxStringBytes;
+ this.aggregateMultipleValues = aggregateMultipleValues == null ? true : aggregateMultipleValues;
}
@Override
public Aggregator factorize(ColumnSelectorFactory metricFactory)
{
- return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes);
+ return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues);
}
@Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{
- return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes);
+ return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues);
}
@Override
@@ -90,13 +93,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
return new StringAnyVectorAggregator(
null,
selectorFactory.makeMultiValueDimensionSelector(DefaultDimensionSpec.of(fieldName)),
- maxStringBytes
+ maxStringBytes,
+ aggregateMultipleValues
);
} else {
return new StringAnyVectorAggregator(
selectorFactory.makeSingleValueDimensionSelector(DefaultDimensionSpec.of(fieldName)),
null,
- maxStringBytes
+ maxStringBytes,
+ aggregateMultipleValues
);
}
}
@@ -122,7 +127,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
@Override
public AggregatorFactory getCombiningFactory()
{
- return new StringAnyAggregatorFactory(name, name, maxStringBytes);
+ return new StringAnyAggregatorFactory(name, name, maxStringBytes, aggregateMultipleValues);
}
@Override
@@ -155,6 +160,11 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
{
return maxStringBytes;
}
+ @JsonProperty
+ public boolean getAggregateMultipleValues()
+ {
+ return aggregateMultipleValues;
+ }
@Override
public List<String> requiredFields()
@@ -192,7 +202,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
@Override
public AggregatorFactory withName(String newName)
{
- return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes());
+ return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes(), getAggregateMultipleValues());
}
@Override
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java
index 32bb3153fa2..86b8c51a469 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java
@@ -25,6 +25,7 @@ import org.apache.druid.segment.BaseObjectColumnValueSelector;
import org.apache.druid.segment.DimensionHandlerUtils;
import java.nio.ByteBuffer;
+import java.util.List;
public class StringAnyBufferAggregator implements BufferAggregator
{
@@ -34,11 +35,13 @@ public class StringAnyBufferAggregator implements BufferAggregator
private final BaseObjectColumnValueSelector valueSelector;
private final int maxStringBytes;
+ private final boolean aggregateMultipleValues;
- public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes)
+ public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues)
{
this.valueSelector = valueSelector;
this.maxStringBytes = maxStringBytes;
+ this.aggregateMultipleValues = aggregateMultipleValues;
}
@Override
@@ -51,8 +54,7 @@ public class StringAnyBufferAggregator implements BufferAggregator
public void aggregate(ByteBuffer buf, int position)
{
if (buf.getInt(position) == NOT_FOUND_FLAG_VALUE) {
- final Object object = valueSelector.getObject();
- String foundValue = DimensionHandlerUtils.convertObjectToString(object);
+ String foundValue = readValue(valueSelector.getObject());
if (foundValue != null) {
ByteBuffer mutationBuffer = buf.duplicate();
mutationBuffer.position(position + FOUND_VALUE_OFFSET);
@@ -65,6 +67,27 @@ public class StringAnyBufferAggregator implements BufferAggregator
}
}
+ private String readValue(Object object)
+ {
+ if (object == null) {
+ return null;
+ }
+ if (object instanceof List) {
+ List<Object> objectList = (List) object;
+ if (objectList.size() == 0) {
+ return null;
+ }
+ if (objectList.size() == 1) {
+ return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
+ }
+ if (aggregateMultipleValues) {
+ return DimensionHandlerUtils.convertObjectToString(objectList);
+ }
+ return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
+ }
+ return DimensionHandlerUtils.convertObjectToString(object);
+ }
+
@Override
public Object get(ByteBuffer buf, int position)
{
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java
index 620801bafa3..104ee7a77bf 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java
@@ -23,12 +23,15 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.segment.data.IndexedInts;
import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
public class StringAnyVectorAggregator implements VectorAggregator
{
@@ -43,11 +46,13 @@ public class StringAnyVectorAggregator implements VectorAggregator
@Nullable
private final MultiValueDimensionVectorSelector multiValueSelector;
private final int maxStringBytes;
+ private final boolean aggregateMultipleValues;
public StringAnyVectorAggregator(
SingleValueDimensionVectorSelector singleValueSelector,
MultiValueDimensionVectorSelector multiValueSelector,
- int maxStringBytes
+ int maxStringBytes,
+ final boolean aggregateMultipleValues
)
{
Preconditions.checkState(
@@ -61,6 +66,7 @@ public class StringAnyVectorAggregator implements VectorAggregator
this.multiValueSelector = multiValueSelector;
this.singleValueSelector = singleValueSelector;
this.maxStringBytes = maxStringBytes;
+ this.aggregateMultipleValues = aggregateMultipleValues;
}
@Override
@@ -78,7 +84,7 @@ public class StringAnyVectorAggregator implements VectorAggregator
if (startRow < rows.length) {
IndexedInts row = rows[startRow];
@Nullable
- String foundValue = row.size() == 0 ? null : multiValueSelector.lookupName(row.get(0));
+ String foundValue = readValue(row);
putValue(buf, position, foundValue);
}
} else if (singleValueSelector != null) {
@@ -93,6 +99,24 @@ public class StringAnyVectorAggregator implements VectorAggregator
}
}
+ private String readValue(IndexedInts row)
+ {
+ if (row.size() == 0) {
+ return null;
+ }
+ if (aggregateMultipleValues) {
+ if (row.size() == 1) {
+ return multiValueSelector.lookupName(row.get(0));
+ }
+ List<String> arrayList = new ArrayList<>();
+ row.forEach(rowIndex -> {
+ arrayList.add(multiValueSelector.lookupName(rowIndex));
+ });
+ return DimensionHandlerUtils.convertObjectToString(arrayList);
+ }
+ return multiValueSelector.lookupName(row.get(0));
+ }
+
@Override
public void aggregate(
ByteBuffer buf,
@@ -142,4 +166,9 @@ public class StringAnyVectorAggregator implements VectorAggregator
buf.putInt(position, FOUND_AND_NULL_FLAG_VALUE);
}
}
+
+ public boolean isAggregateMultipleValues()
+ {
+ return aggregateMultipleValues;
+ }
}
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java
index 87d0e3dfdd8..1fea653c6f7 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java
@@ -148,7 +148,7 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest
// string aggregators
new StringFirstAggregatorFactory("stringFirst", "col", null, 1024),
new StringLastAggregatorFactory("stringLast", "col", null, 1024),
- new StringAnyAggregatorFactory("stringAny", "col", 1024),
+ new StringAnyAggregatorFactory("stringAny", "col", 1024, true),
// sketch aggs
new CardinalityAggregatorFactory("cardinality", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false),
new HyperUniquesAggregatorFactory("hyperUnique", "hyperunique"),
@@ -307,7 +307,8 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest
// string aggregators
new StringFirstAggregatorFactory("col", "col", null, 1024),
new StringLastAggregatorFactory("col", "col", null, 1024),
- new StringAnyAggregatorFactory("col", "col", 1024),
+ new StringAnyAggregatorFactory("col", "col", 1024, true),
+ new StringAnyAggregatorFactory("col", "col", 1024, false),
// sketch aggs
new CardinalityAggregatorFactory("col", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false),
new HyperUniquesAggregatorFactory("col", "hyperunique"),
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java
index 208cfeb052d..b728049d625 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java
@@ -49,7 +49,7 @@ public class StringAnyAggregationTest
@Before
public void setup()
{
- stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE);
+ stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE, true);
combiningAggFactory = stringAnyAggFactory.getCombiningFactory();
valueSelector = new TestObjectColumnSelector<>(strings);
objectSelector = new TestObjectColumnSelector<>(strings);
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java
index c480b9bb2ef..88351125f55 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java
@@ -19,48 +19,44 @@
package org.apache.druid.query.aggregation.any;
-import org.apache.druid.segment.ColumnInspector;
+import com.google.common.collect.Lists;
+import org.apache.druid.query.aggregation.Aggregator;
+import org.apache.druid.query.aggregation.TestObjectColumnSelector;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
-import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
-import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
-import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
+import org.apache.druid.segment.vector.TestVectorColumnSelectorFactory;
+import org.apache.druid.segment.virtual.FallbackVirtualColumnTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.junit.MockitoJUnitRunner;
-import static org.mockito.ArgumentMatchers.any;
+import java.util.List;
-@RunWith(MockitoJUnitRunner.class)
public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
{
private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME";
private static final int MAX_STRING_BYTES = 10;
- @Mock
- private ColumnInspector columnInspector;
- @Mock
+ private TestColumnSelectorFactory columnInspector;
private ColumnCapabilities capabilities;
- @Mock
- private VectorColumnSelectorFactory vectorSelectorFactory;
- @Mock
- private SingleValueDimensionVectorSelector singleValueDimensionVectorSelector;
- @Mock
- private MultiValueDimensionVectorSelector multiValueDimensionVectorSelector;
-
+ private TestVectorColumnSelectorFactory vectorSelectorFactory;
private StringAnyAggregatorFactory target;
@Before
public void setUp()
{
- Mockito.doReturn(capabilities).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME);
- Mockito.doReturn(ColumnCapabilities.Capable.UNKNOWN).when(capabilities).hasMultipleValues();
- target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES);
+ target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, true);
+ columnInspector = new TestColumnSelectorFactory();
+ vectorSelectorFactory = new TestVectorColumnSelectorFactory();
+ capabilities = ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true);
+ vectorSelectorFactory.addCapabilities(FIELD_NAME, capabilities);
+ vectorSelectorFactory.addMVDVS(FIELD_NAME, new FallbackVirtualColumnTest.SameMultiVectorSelector());
}
@Test
@@ -72,10 +68,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test
public void factorizeVectorWithoutCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{
- Mockito.doReturn(null).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME);
- Mockito.doReturn(singleValueDimensionVectorSelector)
- .when(vectorSelectorFactory)
- .makeSingleValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator);
}
@@ -83,9 +75,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test
public void factorizeVectorWithUnknownCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{
- Mockito.doReturn(multiValueDimensionVectorSelector)
- .when(vectorSelectorFactory)
- .makeMultiValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator);
}
@@ -93,10 +82,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test
public void factorizeVectorWithMultipleValuesCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{
- Mockito.doReturn(ColumnCapabilities.Capable.TRUE).when(capabilities).hasMultipleValues();
- Mockito.doReturn(multiValueDimensionVectorSelector)
- .when(vectorSelectorFactory)
- .makeMultiValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator);
}
@@ -104,11 +89,86 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test
public void factorizeVectorWithoutMultipleValuesCapabilitiesShouldReturnAggregatorWithSingleDimensionSelector()
{
- Mockito.doReturn(ColumnCapabilities.Capable.FALSE).when(capabilities).hasMultipleValues();
- Mockito.doReturn(singleValueDimensionVectorSelector)
- .when(vectorSelectorFactory)
- .makeSingleValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator);
}
+
+ @Test
+ public void testFactorize()
+ {
+ Aggregator res = target.factorize(new TestColumnSelectorFactory());
+ Assert.assertTrue(res instanceof StringAnyAggregator);
+ res.aggregate();
+ Assert.assertEquals(null, res.get());
+ StringAnyVectorAggregator vectorAggregator = target.factorizeVector(vectorSelectorFactory);
+ Assert.assertTrue(vectorAggregator.isAggregateMultipleValues());
+ }
+
+ @Test
+ public void testSvdStringAnyAggregator()
+ {
+ TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
+ Aggregator res = target.factorize(columnSelectorFactory);
+ Assert.assertTrue(res instanceof StringAnyAggregator);
+ columnSelectorFactory.moveSelectorCursorToNext();
+ res.aggregate();
+ Assert.assertEquals("CCCC", res.get());
+ }
+
+ @Test
+ public void testMvdStringAnyAggregator()
+ {
+ TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
+ Aggregator res = target.factorize(columnSelectorFactory);
+ Assert.assertTrue(res instanceof StringAnyAggregator);
+ columnSelectorFactory.moveSelectorCursorToNext();
+ columnSelectorFactory.moveSelectorCursorToNext();
+ res.aggregate();
+ Assert.assertEquals("[AAAA, AAA", res.get());
+ }
+
+ @Test
+ public void testMvdStringAnyAggregatorWithAggregateMultipleToFalse()
+ {
+ StringAnyAggregatorFactory target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, false);
+ TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
+ Aggregator res = target.factorize(columnSelectorFactory);
+ Assert.assertTrue(res instanceof StringAnyAggregator);
+ columnSelectorFactory.moveSelectorCursorToNext();
+ columnSelectorFactory.moveSelectorCursorToNext();
+ res.aggregate();
+ // picks up first value in mvd list
+ Assert.assertEquals("AAAA", res.get());
+ }
+
+ static class TestColumnSelectorFactory implements ColumnSelectorFactory
+ {
+ List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
+ final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
+ Integer maxStringBytes = 1024;
+ TestObjectColumnSelector<Object> objectColumnSelector = new TestObjectColumnSelector<>(mvds);
+
+ @Override
+ public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec)
+ {
+ return null;
+ }
+
+ @Override
+ public ColumnValueSelector<?> makeColumnValueSelector(String columnName)
+ {
+ return objectColumnSelector;
+ }
+
+ @Override
+ public ColumnCapabilities getColumnCapabilities(String columnName)
+ {
+ return ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true);
+ }
+
+ public void moveSelectorCursorToNext()
+ {
+ objectColumnSelector.increment();
+ }
+ }
}
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java
index 658db6f7eec..1db6cbe2b1d 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java
@@ -19,15 +19,22 @@
package org.apache.druid.query.aggregation.any;
+import com.google.common.collect.Lists;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.TestObjectColumnSelector;
import org.junit.Assert;
import org.junit.Test;
import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
public class StringAnyBufferAggregatorTest
{
+ StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
+ "billy", "billy", 1024, true
+ );
+
private void aggregateBuffer(
TestObjectColumnSelector valueSelector,
BufferAggregator agg,
@@ -44,17 +51,14 @@ public class StringAnyBufferAggregatorTest
{
final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
-
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes,
+ true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -75,17 +79,15 @@ public class StringAnyBufferAggregatorTest
public void testBufferAggregateWithFoldCheck()
{
final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes,
+ true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -108,17 +110,14 @@ public class StringAnyBufferAggregatorTest
{
final String[] strings = {"CCCC", "AAAA", "BBBB", null, "EEEE"};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
-
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes,
+ true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -140,17 +139,13 @@ public class StringAnyBufferAggregatorTest
{
final String[] strings = {null, "CCCC", "AAAA", "BBBB", "EEEE"};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
-
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes, true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -170,19 +165,15 @@ public class StringAnyBufferAggregatorTest
@Test
public void testNonStringValue()
{
-
final Double[] doubles = {1.00, 2.00};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector<Double> objectColumnSelector = new
TestObjectColumnSelector<>(doubles);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
-
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes,
+ true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -198,4 +189,77 @@ public class StringAnyBufferAggregatorTest
Assert.assertEquals("1.0", result);
}
+
+ @Test
+ public void testMvds()
+ {
+ List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
+ final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
+ int maxStringBytes = 1024;
+
+ TestObjectColumnSelector<Object> objectColumnSelector = new
TestObjectColumnSelector<>(mvds);
+
+ StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
+ objectColumnSelector,
+ maxStringBytes, true
+ );
+
+ ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2);
+ int position = 0;
+
+ int[] positions = new int[]{0, 1, 43, 100, 189};
+ Arrays.stream(positions).forEach(i -> agg.init(buf, i));
+
+ //noinspection ForLoopReplaceableByForEach
+ for (int i = 0; i < mvds.length; i++) {
+ aggregateBuffer(objectColumnSelector, agg, buf, positions[i]);
+ }
+ String result = ((String) agg.get(buf, position));
+ Assert.assertNull(result);
+
+ for (int i = 0; i < positions.length; i++) {
+ if (i == 2) {
+ Assert.assertEquals(mvd.toString(), agg.get(buf, positions[2]));
+ } else {
+ Assert.assertEquals(mvds[i], agg.get(buf, positions[i]));
+ }
+ }
+ }
+
+ @Test
+ public void testMvdsWithCustomAggregate()
+ {
+ List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
+ final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
+ final int maxStringBytes = 1024;
+
+ TestObjectColumnSelector<Object> objectColumnSelector = new
TestObjectColumnSelector<>(mvds);
+
+ StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
+ objectColumnSelector,
+ maxStringBytes, false
+ );
+
+ ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2);
+ int position = 0;
+
+ int[] positions = new int[]{0, 1, 43, 100, 189};
+ Arrays.stream(positions).forEach(i -> agg.init(buf, i));
+
+ //noinspection ForLoopReplaceableByForEach
+ for (int i = 0; i < mvds.length; i++) {
+ aggregateBuffer(objectColumnSelector, agg, buf, positions[i]);
+ }
+ String result = ((String) agg.get(buf, position));
+ Assert.assertNull(result);
+
+ for (int i = 0; i < positions.length; i++) {
+ if (i == 2) {
+ // takes first in case of mvds
+ Assert.assertEquals(mvd.get(0), agg.get(buf, positions[2]));
+ } else {
+ Assert.assertEquals(mvds[i], agg.get(buf, positions[i]));
+ }
+ }
+ }
}
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java
index bb9ca74dfb4..b6555f6d2af 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java
@@ -33,6 +33,8 @@ import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import static
org.apache.druid.query.aggregation.any.StringAnyVectorAggregator.NOT_FOUND_FLAG_VALUE;
@@ -61,6 +63,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
private StringAnyVectorAggregator singleValueTarget;
private StringAnyVectorAggregator multiValueTarget;
+ private StringAnyVectorAggregator customMultiValueTarget;
@Before
public void setUp()
@@ -74,20 +77,22 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
return index >= DICTIONARY.length ? null : DICTIONARY[index];
}).when(singleValueSelector).lookupName(anyInt());
initializeRandomBuffer();
- singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES);
- multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES);
+ singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES, true);
+ multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, true);
+ // customMultiValueTarget aggregates to only single value in case of MVDs
+ customMultiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, false);
}
@Test(expected = IllegalStateException.class)
public void initWithBothSingleAndMultiValueSelectorShouldThrowException()
{
- new StringAnyVectorAggregator(singleValueSelector, multiValueSelector,
MAX_STRING_BYTES);
+ new StringAnyVectorAggregator(singleValueSelector, multiValueSelector,
MAX_STRING_BYTES, true);
}
@Test(expected = IllegalStateException.class)
public void initWithNeitherSingleNorMultiValueSelectorShouldThrowException()
{
- new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES);
+ new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES, true);
}
@Test
@@ -122,7 +127,7 @@ public class StringAnyVectorAggregatorTest extends
InitializedNullHandlingTest
public void aggregateMultiValuePositionNotFoundShouldPutFirstValue()
{
multiValueTarget.aggregate(buf, POSITION, 0, 2);
- Assert.assertEquals(DICTIONARY[1], multiValueTarget.get(buf, POSITION));
+ Assert.assertEquals("[One, Zero]", multiValueTarget.get(buf, POSITION));
}
@Test
@@ -155,9 +160,9 @@ public class StringAnyVectorAggregatorTest extends
InitializedNullHandlingTest
@Test
public void aggregateBatchWithRowsShouldAggregateAllRows()
{
- int[] positions = new int[] {0, 43, 100};
+ int[] positions = new int[]{0, 43, 100};
int positionOffset = 2;
- int[] rows = new int[] {2, 1, 0};
+ int[] rows = new int[]{2, 1, 0};
clearBufferForPositions(positionOffset, positions);
multiValueTarget.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
@@ -166,8 +171,32 @@ public class StringAnyVectorAggregatorTest extends
InitializedNullHandlingTest
IndexedInts rowIndex = MULTI_VALUE_ROWS[row];
if (rowIndex.size() == 0) {
Assert.assertNull(multiValueTarget.get(buf, position));
- } else {
+ } else if (rowIndex.size() == 1) {
Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), multiValueTarget.get(buf, position));
+ } else {
+ List<String> res = new ArrayList<>();
+ rowIndex.forEach(index -> res.add(multiValueSelector.lookupName(index)));
+ Assert.assertEquals(res.toString(), multiValueTarget.get(buf, position));
+ }
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRowsShouldAggregateAllRowsWithAggregateMVDFalse()
+ {
+ int[] positions = new int[]{0, 43, 100};
+ int positionOffset = 2;
+ int[] rows = new int[]{2, 1, 0};
+ clearBufferForPositions(positionOffset, positions);
+ customMultiValueTarget.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ int position = positions[i] + positionOffset;
+ int row = rows[i];
+ IndexedInts rowIndex = MULTI_VALUE_ROWS[row];
+ if (rowIndex.size() == 0) {
+ Assert.assertNull(customMultiValueTarget.get(buf, position));
+ } else {
+ Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)),
customMultiValueTarget.get(buf, position));
}
}
}
diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java
index ac6b5446144..72de5466ff6 100644
--- a/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java
@@ -489,7 +489,7 @@ public class FallbackVirtualColumnTest
}
}
- private static class SameMultiVectorSelector implements
MultiValueDimensionVectorSelector
+ public static class SameMultiVectorSelector implements
MultiValueDimensionVectorSelector
{
@Override
public int getValueCardinality()
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
index efa3a9e7e32..21bcc833e04 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
@@ -95,7 +95,8 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
String fieldName,
String timeColumn,
ColumnType type,
- Integer maxStringBytes
+ Integer maxStringBytes,
+ Boolean aggregateMultipleValues
)
{
switch (type.getType()) {
@@ -121,7 +122,8 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
String fieldName,
String timeColumn,
ColumnType type,
- Integer maxStringBytes
+ Integer maxStringBytes,
+ Boolean aggregateMultipleValues
)
{
switch (type.getType()) {
@@ -147,7 +149,8 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
String fieldName,
String timeColumn,
ColumnType type,
- Integer maxStringBytes
+ Integer maxStringBytes,
+ Boolean aggregateMultipleValues
)
{
switch (type.getType()) {
@@ -158,7 +161,7 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
case DOUBLE:
return new DoubleAnyAggregatorFactory(name, fieldName);
case STRING:
- return new StringAnyAggregatorFactory(name, fieldName,
maxStringBytes);
+ return new StringAnyAggregatorFactory(name, fieldName,
maxStringBytes, aggregateMultipleValues);
default:
throw SimpleSqlAggregator.badTypeException(fieldName, "ANY", type);
}
@@ -170,7 +173,8 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
String fieldName,
String timeColumn,
ColumnType outputType,
- Integer maxStringBytes
+ Integer maxStringBytes,
+ Boolean aggregateMultipleValues
);
}
@@ -244,37 +248,38 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
final AggregatorFactory theAggFactory;
switch (args.size()) {
case 1:
- theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName,
fieldName, null, outputType, null);
+ theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName,
fieldName, null, outputType, null, true);
break;
case 2:
- int maxStringBytes;
- try {
- maxStringBytes = RexLiteral.intValue(rexNodes.get(1));
- }
- catch (AssertionError ae) {
- plannerContext.setPlanningError(
- "The second argument '%s' to function '%s' is not a number",
- rexNodes.get(1),
- aggregateCall.getName()
- );
- return null;
- }
+ Integer maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // added not null check at the function
theAggFactory = aggregatorType.createAggregatorFactory(
aggregatorName,
fieldName,
null,
outputType,
- maxStringBytes
+ maxStringBytes.intValue(),
+ true
+ );
+ break;
+ case 3:
+ maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // added not null check at the function for rexNode 1,2
+ boolean aggregateMultipleValues = RexLiteral.booleanValue(rexNodes.get(2));
+ theAggFactory = aggregatorType.createAggregatorFactory(
+ aggregatorName,
+ fieldName,
+ null,
+ outputType,
+ maxStringBytes,
+ aggregateMultipleValues
);
break;
default:
throw InvalidSqlInput.exception(
- "Function [%s] expects 1 or 2 arguments but found [%s]",
+ "Function [%s] expects 1 or 2 or 3 arguments but found [%s]",
aggregateCall.getName(),
args.size()
);
}
-
return Aggregation.create(
Collections.singletonList(theAggFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name,
aggregatorName) : null
@@ -372,10 +377,11 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
InferTypes.RETURN_TYPE,
DefaultOperandTypeChecker
.builder()
- .operandNames("expr", "maxBytesPerString")
- .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
+ .operandNames("expr", "maxBytesPerStringInt", "aggregateMultipleValuesBoolean")
+ .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.BOOLEAN)
.requiredOperandCount(1)
- .literalOperands(1)
+ .literalOperands(1, 2)
+ .notNullOperands(1, 2)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
@@ -402,9 +408,9 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
SqlParserPos pos = call.getParserPosition();
- if (operands.isEmpty() || operands.size() > 2) {
+ if (operands.isEmpty() || operands.size() > 3) {
throw InvalidSqlInput.exception(
- "Function [%s] expects 1 or 2 arguments but found [%s]",
+ "Function [%s] expects 1 or 2 or 3 arguments but found [%s]",
getName(),
operands.size()
);
@@ -417,6 +423,9 @@ public class EarliestLatestAnySqlAggregator implements
SqlAggregator
if (operands.size() == 2) {
newOperands.add(operands.get(1));
}
+ if (operands.size() == 3) {
+ newOperands.add(operands.get(2));
+ }
return replacementAggFunc.createCall(pos, newOperands);
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java
index 03e23503a81..fac88d853e1 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java
@@ -119,7 +119,8 @@ public class EarliestLatestBySqlAggregator implements
SqlAggregator
rexNodes.get(1)
),
outputType,
- null
+ null,
+ true
);
break;
case 3:
@@ -145,7 +146,8 @@ public class EarliestLatestBySqlAggregator implements
SqlAggregator
rexNodes.get(1)
),
outputType,
- maxStringBytes
+ maxStringBytes,
+ true
);
break;
default:
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java
index f43fde3a935..a52ce5707c1 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java
@@ -188,6 +188,7 @@ public class DefaultOperandTypeChecker implements
SqlOperandTypeChecker
@Nullable
private Integer requiredOperandCount;
private int[] literalOperands;
+ private IntSet notNullOperands = new IntArraySet();
private Builder()
{
@@ -229,6 +230,12 @@ public class DefaultOperandTypeChecker implements
SqlOperandTypeChecker
return this;
}
+ public Builder notNullOperands(final int... notNullOperands)
+ {
+ Arrays.stream(notNullOperands).forEach(this.notNullOperands::add);
+ return this;
+ }
+
public DefaultOperandTypeChecker build()
{
int computedRequiredOperandCount = requiredOperandCount == null ?
operandTypes.size() : requiredOperandCount;
@@ -236,16 +243,18 @@ public class DefaultOperandTypeChecker implements
SqlOperandTypeChecker
operandNames,
operandTypes,
computedRequiredOperandCount,
- DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()),
+ DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size(), notNullOperands),
literalOperands
);
}
}
- public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount)
+ public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount, IntSet notNullOperands)
{
final IntSet nullableOperands = new IntArraySet();
- IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add);
+ IntStream.range(requiredOperandCount, totalOperandCount)
+ .filter(i -> !notNullOperands.contains(i))
+ .forEach(nullableOperands::add);
return nullableOperands;
}
}
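
Note on the DefaultOperandTypeChecker change above: the new filter removes any index listed in notNullOperands from the set of nullable optional operands. The following is a minimal standalone sketch of that computation, not the Druid code itself. The class name NullableOperandsSketch is hypothetical, and java.util sets stand in for the fastutil IntSet used in the commit.

```java
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class NullableOperandsSketch
{
  // Hypothetical stand-in for buildNullableOperands: optional operands
  // (indexes requiredOperandCount..totalOperandCount-1) are nullable unless
  // they are explicitly listed in notNullOperands.
  public static Set<Integer> nullableOperands(
      int requiredOperandCount,
      int totalOperandCount,
      Set<Integer> notNullOperands
  )
  {
    return IntStream.range(requiredOperandCount, totalOperandCount)
                    .filter(i -> !notNullOperands.contains(i))
                    .boxed()
                    .collect(Collectors.toCollection(TreeSet::new));
  }

  public static void main(String[] args)
  {
    // Previous behavior (no not-null operands): both optional operands are nullable.
    System.out.println(nullableOperands(1, 3, Set.of()));     // prints [1, 2]
    // ANY_VALUE after this commit: notNullOperands(1, 2), so none are nullable.
    System.out.println(nullableOperands(1, 3, Set.of(1, 2))); // prints []
  }
}
```

With one required operand, three operand types, and notNullOperands(1, 2), as configured for ANY_VALUE in this commit, no operand is left nullable. This is consistent with the "Illegal use of 'NULL'" errors asserted in testStringAnyAggArgValidation.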
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
index d655e12f799..450d6208240 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
@@ -593,7 +593,7 @@ public class OperatorConversions
{
final IntSet nullableOperands = requiredOperandCount == null
? new IntArraySet()
- :
DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount,
operandTypes.size());
+ :
DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount,
operandTypes.size(), new IntArraySet());
if (operandTypeInference == null) {
SqlOperandTypeInference defaultInference = new
DefaultOperandTypeInference(operandTypes, nullableOperands);
return (callBinding, returnType, types) -> {
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
index a8e16de2675..3ead14a05a3 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
@@ -497,7 +497,7 @@ public class CalciteJoinQueryTest extends
BaseCalciteQueryTest
cannotVectorize();
testQuery(
- "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n"
+ "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, true) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n"
+ " (\n"
+ " SELECT TIME_FLOOR(__time, 'PT1H') AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n"
+ " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n"
@@ -532,7 +532,7 @@ public class CalciteJoinQueryTest extends
BaseCalciteQueryTest
)
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
- new StringAnyAggregatorFactory("a0", "dim3", 100)
+ new StringAnyAggregatorFactory("a0", "dim3", 100,
true)
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
@@ -598,7 +598,7 @@ public class CalciteJoinQueryTest extends
BaseCalciteQueryTest
)
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
- new StringAnyAggregatorFactory("a0", "dim3", 100)
+ new StringAnyAggregatorFactory("a0", "dim3", 100,
true)
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
@@ -609,6 +609,71 @@ public class CalciteJoinQueryTest extends
BaseCalciteQueryTest
)
);
}
+ @Test
+ public void testJoinOnGroupByInsteadOfTimeseriesWithFloorOnTimeWithNoAggregateMultipleValues()
+ {
+ // Cannot vectorize JOIN operator.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, false) FROM foo WHERE (CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1, m1) IN\n"
+ + " (\n"
+ + " SELECT CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1 AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n"
+ + " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n"
+ + " )\n"
+ + "GROUP BY 1, 2\n",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(
+ join(
+ new TableDataSource(CalciteTests.DATASOURCE1),
+ new QueryDataSource(
+ GroupByQuery.builder()
+
.setDataSource(CalciteTests.DATASOURCE1)
+
.setInterval(querySegmentSpec(Intervals.of(
+
"1994-04-29/2020-01-11T00:00:00.001Z")))
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+
"(timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1)",
+ ColumnType.LONG
+ )
+ )
+ .setDimFilter(equality("dim3",
"b", ColumnType.STRING))
+
.setGranularity(Granularities.ALL)
+ .setDimensions(dimensions(new
DefaultDimensionSpec(
+ "v0",
+ "d0",
+ ColumnType.LONG
+ )))
+
.setAggregatorSpecs(aggregators(
+ new
FloatMinAggregatorFactory("a0", "m1")
+ ))
+
.setContext(QUERY_CONTEXT_DEFAULT)
+ .build()),
+ "j0.",
+ "(((timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1) == \"j0.d0\") && (\"m1\" == \"j0.a0\"))",
+ JoinType.INNER
+ )
+ )
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setDimensions(
+ new DefaultDimensionSpec("__time", "d0",
ColumnType.LONG),
+ new DefaultDimensionSpec("m1", "d1",
ColumnType.FLOAT)
+ )
+ .setGranularity(Granularities.ALL)
+ .setAggregatorSpecs(aggregators(
+ new StringAnyAggregatorFactory("a0", "dim3", 100,
false)
+ ))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{946684800000L, 1.0f, "a"}, // picks up first from [a, b]
+ new Object[]{946771200000L, 2.0f, "b"} // picks up first from [b, c]
+ )
+ );
+ }
@Test
@Parameters(source = QueryContextForJoinProvider.class)
@@ -1480,7 +1545,7 @@ public class CalciteJoinQueryTest extends
BaseCalciteQueryTest
new SubstringDimExtractionFn(0, 1)
)
)
- .setAggregatorSpecs(new
StringAnyAggregatorFactory("a0", "v", 10))
+ .setAggregatorSpecs(new
StringAnyAggregatorFactory("a0", "v", 10, true))
.build()
),
"j0.",
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 00ea933bb1e..1bca129757b 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -844,10 +844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new LongAnyAggregatorFactory("a0", "cnt"),
new FloatAnyAggregatorFactory("a1", "m1"),
new DoubleAnyAggregatorFactory("a2", "m2"),
- new StringAnyAggregatorFactory("a3", "dim1", 10),
+ new StringAnyAggregatorFactory("a3", "dim1", 10,
true),
new LongAnyAggregatorFactory("a4", "v0"),
new FloatAnyAggregatorFactory("a5", "v1"),
- new StringAnyAggregatorFactory("a6", "v2", 10)
+ new StringAnyAggregatorFactory("a6", "v2", 10, true)
)
)
.context(QUERY_CONTEXT_DEFAULT)
@@ -1420,7 +1420,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setAggregatorSpecs(aggregators(new
StringAnyAggregatorFactory(
"a0:a",
"dim1",
- 10
+ 10, true
)))
.setPostAggregatorSpecs(
ImmutableList.of(
@@ -1565,7 +1565,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.granularity(Granularities.ALL)
.aggregators(
aggregators(
- new StringAnyAggregatorFactory("a0", "dim1", 32),
+ new StringAnyAggregatorFactory("a0", "dim1", 32,
true),
new LongAnyAggregatorFactory("a1", "l2"),
new DoubleAnyAggregatorFactory("a2", "d2"),
new FloatAnyAggregatorFactory("a3", "f2")
@@ -1607,7 +1607,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.filters(filter)
.aggregators(
aggregators(
- new StringAnyAggregatorFactory("a0", "dim1", 32),
+ new StringAnyAggregatorFactory("a0", "dim1", 32,
true),
new LongAnyAggregatorFactory("a1", "l2"),
new DoubleAnyAggregatorFactory("a2", "d2"),
new FloatAnyAggregatorFactory("a3", "f2")
@@ -9422,7 +9422,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.granularity(Granularities.ALL)
.aggregators(
aggregators(
- new StringAnyAggregatorFactory("a0", "dim1", 1024),
+ new StringAnyAggregatorFactory("a0", "dim1", 1024,
true),
new LongAnyAggregatorFactory("a1", "l1"),
new StringFirstAggregatorFactory("a2", "dim1", null,
1024),
new LongFirstAggregatorFactory("a3", "l1", null),
@@ -9741,7 +9741,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setAggregatorSpecs(
aggregators(
new FilteredAggregatorFactory(
- new StringAnyAggregatorFactory("a0",
"dim1", 1024),
+ new StringAnyAggregatorFactory("a0",
"dim1", 1024, true),
equality("dim1", "nonexistent",
ColumnType.STRING)
),
new FilteredAggregatorFactory(
@@ -13533,6 +13533,24 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
);
}
+ @Test
+ public void testStringAnyAggArgValidation()
+ {
+ DruidException e = assertThrows(DruidException.class, () -> testBuilder()
+ .sql("SELECT ANY_VALUE(dim3, 1000, 'true') FROM foo")
+ .queryContext(ImmutableMap.of())
+ .run());
+ assertThat(e, invalidSqlIs(
+ "Cannot apply 'ANY_VALUE' to arguments of type 'ANY_VALUE(<VARCHAR>, <INTEGER>, <CHAR(4)>)'. Supported form(s): 'ANY_VALUE(<expr>, [<maxBytesPerStringInt>, [<aggregateMultipleValuesBoolean>]])' (line [1], column [8])"));
+ DruidException e1 = assertThrows(DruidException.class, () -> testBuilder()
+ .sql("SELECT ANY_VALUE(dim3, 1000, null) FROM foo")
+ .queryContext(ImmutableMap.of()).run());
+ Assert.assertEquals("Illegal use of 'NULL' (line [1], column [30])", e1.getMessage());
+ DruidException e2 = assertThrows(DruidException.class, () -> testBuilder()
+ .sql("SELECT ANY_VALUE(dim3, null, true) FROM foo")
+ .queryContext(ImmutableMap.of()).run());
+ Assert.assertEquals("Illegal use of 'NULL' (line [1], column [24])", e2.getMessage());
+ }
@Test
public void testStringAggMaxBytes()
{
@@ -14367,7 +14385,7 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
new StringLastAggregatorFactory("a1", "dim1",
"__time", 1024),
new StringFirstAggregatorFactory("a2", "dim3",
"__time", 1024),
new StringFirstAggregatorFactory("a3", "dim1",
"__time", 1024),
- new StringAnyAggregatorFactory("a4", "dim3", 1024)))
+ new StringAnyAggregatorFactory("a4", "dim3", 1024,
true)))
.build()
),
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java
index bb76488ccc4..eb8e7e67da7 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java
@@ -323,4 +323,10 @@ public class QueryTestBuilder
return build().resultsOnly();
}
+ public boolean isDecoupledMode()
+ {
+ String mode = (String)
queryContext.getOrDefault(PlannerConfig.CTX_NATIVE_QUERY_SQL_PLANNING_MODE, "");
+ return
PlannerConfig.NATIVE_QUERY_SQL_PLANNING_MODE_DECOUPLED.equalsIgnoreCase(mode);
+ }
+
}
diff --git a/website/.spelling b/website/.spelling
index 002998c442c..14233798fef 100644
--- a/website/.spelling
+++ b/website/.spelling
@@ -2334,3 +2334,4 @@ LAST_VALUE
markUnused
markUsed
segmentId
+aggregateMultipleValues
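
Illustration of the behavior exercised by the tests in this commit. The sketch below approximates what the tests assert for a single row value: with aggregateMultipleValues set to true a multi-value row is kept as the string form of the whole list, and with false only the first element is kept. AnyValueMvdSketch and toAnyValue are hypothetical names, this is not the StringAnyAggregator implementation, and maxStringBytes truncation is deliberately left out. Returning null for an empty list follows the vector aggregator test and is an assumption for the other aggregators.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

public class AnyValueMvdSketch
{
  // Hypothetical helper that mimics the outcomes asserted in testMvds,
  // testMvdsWithCustomAggregate and testNonStringValue for one row value.
  public static String toAnyValue(Object row, boolean aggregateMultipleValues)
  {
    if (row == null) {
      return null;
    }
    if (row instanceof List) {
      List<?> values = (List<?>) row;
      if (values.isEmpty()) {
        return null; // assumption, based on the vector aggregator test
      }
      if (values.size() == 1 || !aggregateMultipleValues) {
        // single-valued row, or "take the first value" mode
        return Objects.toString(values.get(0), null);
      }
      // multi-value row kept as a whole, e.g. "[AAAA, AAAAB, AAAC]"
      return values.toString();
    }
    // non-string scalars are stringified, e.g. 1.00 becomes "1.0"
    return String.valueOf(row);
  }

  public static void main(String[] args)
  {
    List<String> mvd = Arrays.asList("AAAA", "AAAAB", "AAAC");
    System.out.println(toAnyValue(mvd, true));   // prints [AAAA, AAAAB, AAAC]
    System.out.println(toAnyValue(mvd, false));  // prints AAAA
    System.out.println(toAnyValue(1.00, true));  // prints 1.0
    System.out.println(toAnyValue(null, true));  // prints null
  }
}
```

In SQL this corresponds to ANY_VALUE(dim3, 100, true) versus ANY_VALUE(dim3, 100, false), as used in the CalciteJoinQueryTest cases above.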