This is an automated email from the ASF dual-hosted git repository.
suneet pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new c68388ebcd Vectorized version of string last aggregator (#12493)
c68388ebcd is described below
commit c68388ebcd6cc2a77d2f6c41320906b32a5f6028
Author: somu-imply <[email protected]>
AuthorDate: Mon May 9 17:02:38 2022 -0700
Vectorized version of string last aggregator (#12493)
* Vectorized version of string last aggregator
* Updating string last and adding testcases
* Updating code and adding testcases for serializable pairs
* Addressing review comments
---
.../aggregation/first/StringFirstLastUtils.java | 43 +++++
.../last/StringLastAggregatorFactory.java | 28 +++
.../last/StringLastVectorAggregator.java | 190 +++++++++++++++++++++
.../last/StringLastVectorAggregatorTest.java | 167 ++++++++++++++++++
.../apache/druid/sql/calcite/CalciteQueryTest.java | 39 ++++-
5 files changed, 462 insertions(+), 5 deletions(-)
diff --git
a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
index 323ce413e6..6b93be7d70 100644
---
a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
+++
b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
@@ -27,6 +27,8 @@ import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorObjectSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -59,6 +61,47 @@ public class StringFirstLastUtils
|| SerializablePairLongString.class.isAssignableFrom(clazz);
}
+  /**
+   * Tells whether a single value taken from an object vector *might* be a {@link SerializablePairLongString}.
+   * A null value never qualifies. A non-null value qualifies when its runtime class is either a supertype or a
+   * subtype of SerializablePairLongString.
+   */
+  public static boolean objectNeedsFoldCheck(Object obj)
+  {
+    if (obj != null) {
+      final Class<?> objClass = obj.getClass();
+      final boolean isSuperTypeOfPair = objClass.isAssignableFrom(SerializablePairLongString.class);
+      return isSuperTypeOfPair || SerializablePairLongString.class.isAssignableFrom(objClass);
+    }
+    return false;
+  }
+
+  /**
+   * Reads the (time, string) pair located at {@code index} of the given vector selectors.
+   *
+   * If the object at {@code index} is already a {@link SerializablePairLongString}, its own time and string are
+   * used and the time selector is not consulted. Any other non-null object is converted to a string and paired
+   * with the time found at the same index of the time vector.
+   *
+   * Out-of-bounds indexes are the responsibility of the caller; no range check is performed here.
+   *
+   * @param timeSelector  selector supplying the time vector; only read when the object is not already a pair
+   * @param valueSelector selector supplying the object vector
+   * @param index         position to read in both vectors
+   *
+   * @return a newly allocated pair, or null when the object at {@code index} is null (nulls are not aggregated)
+   */
+  @Nullable
+  public static SerializablePairLongString readPairFromVectorSelectorsAtIndex(
+      BaseLongVectorValueSelector timeSelector,
+      VectorObjectSelector valueSelector,
+      int index
+  )
+  {
+    final Object object = valueSelector.getObjectVector()[index];
+    if (object == null) {
+      // Don't aggregate nulls.
+      return null;
+    }
+
+    final long time;
+    final String string;
+    if (object instanceof SerializablePairLongString) {
+      final SerializablePairLongString pair = (SerializablePairLongString) object;
+      time = pair.lhs;
+      string = pair.rhs;
+    } else {
+      time = timeSelector.getLongVector()[index];
+      string = DimensionHandlerUtils.convertObjectToString(object);
+    }
+
+    return new SerializablePairLongString(time, string);
+  }
+
@Nullable
public static SerializablePairLongString readPairFromSelectors(
final BaseLongColumnValueSelector timeSelector,
diff --git
a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
index 71bf66e608..39c5b29647 100644
---
a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
+++
b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
@@ -31,14 +31,20 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
+import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.NilColumnValueSelector;
+import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.vector.VectorObjectSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -141,6 +147,28 @@ public class StringLastAggregatorFactory extends
AggregatorFactory
}
}
+  /**
+   * Reports this aggregator as vectorizable unconditionally; the supplied inspector is not consulted.
+   */
+  @Override
+  public boolean canVectorize(ColumnInspector columnInspector)
+  {
+    return true;
+  }
+
+  /**
+   * Creates the vectorized aggregator for {@code fieldName}.
+   *
+   * When the selector factory reports no capabilities for {@code fieldName} (presumably because the column does
+   * not exist; confirm against the factory contract), the aggregator is built with a null time selector. The
+   * time selector is only created when it is actually going to be used.
+   *
+   * @param selectorFactory factory used to obtain the value and time vector selectors
+   *
+   * @return a {@link StringLastVectorAggregator} reading {@code fieldName} and {@code timeColumn}
+   */
+  @Override
+  public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
+  {
+    final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
+    final VectorObjectSelector valueSelector = selectorFactory.makeObjectSelector(fieldName);
+    // Only build (and cast) the time selector when there are capabilities for the field; otherwise it would be
+    // created just to be thrown away.
+    final BaseLongVectorValueSelector timeSelector =
+        capabilities == null
+        ? null
+        : (BaseLongVectorValueSelector) selectorFactory.makeValueSelector(timeColumn);
+    return new StringLastVectorAggregator(timeSelector, valueSelector, maxStringBytes);
+  }
+
@Override
public Comparator getComparator()
{
diff --git
a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
new file mode 100644
index 0000000000..045360ba61
--- /dev/null
+++
b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.last;
+
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
+import org.apache.druid.segment.DimensionHandlerUtils;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorObjectSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized aggregator that keeps, per buffer position, the string value with the greatest timestamp seen so far.
+ *
+ * The state at each position is a (time, string) pair written through {@link StringFirstLastUtils#writePair}; the
+ * time is stored first, so {@code buf.getLong(position)} yields the timestamp of the currently retained value.
+ *
+ * Input objects may be plain values, which are paired with the time vector, or already-aggregated
+ * {@link SerializablePairLongString} objects, which carry their own timestamp and are "folded".
+ */
+public class StringLastVectorAggregator implements VectorAggregator
+{
+  /** Initial state: the minimum timestamp paired with a null string, so any real row replaces it. */
+  private static final SerializablePairLongString INIT = new SerializablePairLongString(
+      DateTimes.MIN.getMillis(),
+      null
+  );
+  /** May be null, in which case both aggregate methods leave the buffer untouched. */
+  @Nullable
+  private final BaseLongVectorValueSelector timeSelector;
+  private final VectorObjectSelector valueSelector;
+  private final int maxStringBytes;
+  /** Timestamp of the value retained at the position handled by the last single-position aggregate call. */
+  protected long lastTime;
+
+  /**
+   * @param timeSelector   selector for the time vector, or null when there is nothing to aggregate
+   * @param valueSelector  selector for the object vector holding strings or foldable pairs
+   * @param maxStringBytes limit handed to {@link StringFirstLastUtils#writePair} for the stored string
+   */
+  public StringLastVectorAggregator(
+      @Nullable final BaseLongVectorValueSelector timeSelector,
+      final VectorObjectSelector valueSelector,
+      final int maxStringBytes
+  )
+  {
+    this.timeSelector = timeSelector;
+    this.valueSelector = valueSelector;
+    this.maxStringBytes = maxStringBytes;
+  }
+
+  /**
+   * Writes the initial (minimum time, null string) pair at {@code position}.
+   */
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    StringFirstLastUtils.writePair(buf, position, INIT, maxStringBytes);
+  }
+
+  /**
+   * Aggregates rows {@code [startRow, endRow)} of the current vector into the single state at {@code position}.
+   *
+   * Rows are scanned from the end of the range backwards and null objects are skipped. For plain values the scan
+   * stops at the first row whose time is older than the retained time. Foldable pairs are compared using the
+   * timestamp they carry rather than the row time vector, so every non-null pair in the range is examined.
+   */
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+    if (timeSelector == null) {
+      return;
+    }
+    final long[] times = timeSelector.getLongVector();
+    final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
+
+    lastTime = buf.getLong(position);
+    for (int i = endRow - 1; i >= startRow; i--) {
+      final Object object = objectsWhichMightBeStrings[i];
+      if (object == null) {
+        continue;
+      }
+      if (StringFirstLastUtils.objectNeedsFoldCheck(object)) {
+        // Less efficient code path when folding is a possibility: the pair's own timestamp decides, so the row
+        // time vector must not be used to skip the row or to end the scan early.
+        final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex(
+            timeSelector,
+            valueSelector,
+            i
+        );
+        if (inPair != null && inPair.lhs >= lastTime) {
+          // Keep the field in step with what is stored in the buffer.
+          lastTime = inPair.lhs;
+          StringFirstLastUtils.writePair(buf, position, inPair, maxStringBytes);
+        }
+      } else {
+        final long time = times[i];
+        if (time < lastTime) {
+          // Scanning backwards: stop at the first row older than the retained value.
+          break;
+        }
+        lastTime = time;
+        StringFirstLastUtils.writePair(
+            buf,
+            position,
+            new SerializablePairLongString(time, DimensionHandlerUtils.convertObjectToString(object)),
+            maxStringBytes
+        );
+      }
+    }
+  }
+
+  /**
+   * Aggregates {@code numRows} rows, each into its own state at {@code positions[i] + positionOffset}.
+   *
+   * @param rows row indexes into the current vector, or null to use {@code i} itself as the row index
+   *
+   * Whether folding is needed is decided once, from the first non-null object of the whole object vector.
+   * Foldable pairs are compared using the timestamp they carry; plain values use the time vector.
+   *
+   * NOTE(review): on the non-fold path a null object is written as a null string when its row time is not older
+   * than the retained time; the single-position overload skips nulls instead. Confirm which is intended.
+   */
+  @Override
+  public void aggregate(
+      ByteBuffer buf,
+      int numRows,
+      int[] positions,
+      @Nullable int[] rows,
+      int positionOffset
+  )
+  {
+    if (timeSelector == null) {
+      // Same contract as the single-position overload: nothing to aggregate without a time selector.
+      return;
+    }
+    final long[] timeVector = timeSelector.getLongVector();
+    final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
+
+    // iterate once over the object vector to find first non null element and
+    // determine if the type is Pair or not
+    boolean foldNeeded = false;
+    for (Object obj : objectsWhichMightBeStrings) {
+      if (obj != null) {
+        foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(obj);
+        break;
+      }
+    }
+
+    for (int i = 0; i < numRows; i++) {
+      final int position = positions[i] + positionOffset;
+      final int row = rows == null ? i : rows[i];
+      final long retainedTime = buf.getLong(position);
+      if (foldNeeded) {
+        final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex(
+            timeSelector,
+            valueSelector,
+            row
+        );
+        if (inPair != null && inPair.lhs >= retainedTime) {
+          StringFirstLastUtils.writePair(buf, position, inPair, maxStringBytes);
+        }
+      } else if (timeVector[row] >= retainedTime) {
+        StringFirstLastUtils.writePair(
+            buf,
+            position,
+            new SerializablePairLongString(
+                timeVector[row],
+                DimensionHandlerUtils.convertObjectToString(objectsWhichMightBeStrings[row])
+            ),
+            maxStringBytes
+        );
+      }
+    }
+  }
+
+  /**
+   * Reads back the (time, string) pair stored at {@code position}.
+   */
+  @Nullable
+  @Override
+  public Object get(ByteBuffer buf, int position)
+  {
+    return StringFirstLastUtils.readPair(buf, position);
+  }
+
+  @Override
+  public void close()
+  {
+    // nothing to close
+  }
+}
+
diff --git
a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
new file mode 100644
index 0000000000..428ff3e374
--- /dev/null
+++
b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.last;
+
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.vector.VectorObjectSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+
+/**
+ * Unit tests for {@link StringLastVectorAggregator} and for the vectorized path of
+ * {@link StringLastAggregatorFactory}, using mocked vector selectors over small fixed fixtures.
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
+{
+  private static final String[] VALUES = new String[]{"a", "b", null, "c"};
+  private static final String NAME = "NAME";
+  private static final String FIELD_NAME = "FIELD_NAME";
+  private static final String TIME_COL = "__time";
+  // Row times matching VALUES index by index.
+  private final long[] times = {2436, 6879, 7888, 8224};
+  // Row times served alongside the pairs fixture; both rows share the same time.
+  private final long[] timesSame = {2436, 2436};
+  private final SerializablePairLongString[] pairs = {
+      new SerializablePairLongString(2345100L, "last"),
+      new SerializablePairLongString(2345001L, "notLast")
+  };
+
+  @Mock
+  private VectorObjectSelector selector;
+  @Mock
+  private VectorObjectSelector selectorForPairs;
+  @Mock
+  private BaseLongVectorValueSelector timeSelector;
+  @Mock
+  private BaseLongVectorValueSelector timeSelectorForPairs;
+  private ByteBuffer buf;
+  private StringLastVectorAggregator target;
+  private StringLastVectorAggregator targetWithPairs;
+
+  private StringLastAggregatorFactory stringLastAggregatorFactory;
+  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+  private VectorColumnSelectorFactory selectorFactory;
+
+  @Before
+  public void setup()
+  {
+    // Start from random bytes so that tests depend on init() rather than on a zeroed buffer.
+    byte[] randomBytes = new byte[1024];
+    ThreadLocalRandom.current().nextBytes(randomBytes);
+    buf = ByteBuffer.wrap(randomBytes);
+    Mockito.doReturn(VALUES).when(selector).getObjectVector();
+    Mockito.doReturn(times).when(timeSelector).getLongVector();
+    Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector();
+    Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector();
+    target = new StringLastVectorAggregator(timeSelector, selector, 10);
+    targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
+    // Initializes the state at buffer position 0 (offset 0, single position 0).
+    clearBufferForPositions(0, 0);
+
+    Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME);
+    Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL);
+    stringLastAggregatorFactory = new StringLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10);
+  }
+
+  @Test
+  public void testAggregateWithPairs()
+  {
+    targetWithPairs.aggregate(buf, 0, 0, pairs.length);
+    Pair<Long, String> result = (Pair<Long, String>) targetWithPairs.get(buf, 0);
+    // pairs[0] must win: the timestamp it carries (lhs) is greater than that of pairs[1].
+    Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue());
+    Assert.assertEquals(pairs[0].rhs, result.rhs);
+  }
+
+  @Test
+  public void testFactory()
+  {
+    Assert.assertTrue(stringLastAggregatorFactory.canVectorize(selectorFactory));
+    VectorAggregator vectorAggregator = stringLastAggregatorFactory.factorizeVector(selectorFactory);
+    Assert.assertNotNull(vectorAggregator);
+    Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass());
+  }
+
+  @Test
+  public void initValueShouldBeMinDate()
+  {
+    target.init(buf, 0);
+    long initVal = buf.getLong(0);
+    Assert.assertEquals(DateTimes.MIN.getMillis(), initVal);
+  }
+
+  @Test
+  public void aggregate()
+  {
+    target.aggregate(buf, 0, 0, VALUES.length);
+    Pair<Long, String> result = (Pair<Long, String>) target.get(buf, 0);
+    // The last row has the greatest time and a non-null value.
+    Assert.assertEquals(times[3], result.lhs.longValue());
+    Assert.assertEquals(VALUES[3], result.rhs);
+  }
+
+  @Test
+  public void aggregateBatchWithoutRows()
+  {
+    int[] positions = new int[]{0, 43, 70};
+    int positionOffset = 2;
+    clearBufferForPositions(positionOffset, positions);
+    // rows == null: position i receives row i.
+    target.aggregate(buf, 3, positions, null, positionOffset);
+    for (int i = 0; i < positions.length; i++) {
+      Pair<Long, String> result = (Pair<Long, String>) target.get(buf, positions[i] + positionOffset);
+      Assert.assertEquals(times[i], result.lhs.longValue());
+      Assert.assertEquals(VALUES[i], result.rhs);
+    }
+  }
+
+  @Test
+  public void aggregateBatchWithRows()
+  {
+    int[] positions = new int[]{0, 43, 70};
+    int[] rows = new int[]{3, 2, 0};
+    int positionOffset = 2;
+    clearBufferForPositions(positionOffset, positions);
+    // Position i receives row rows[i]; row 2 holds a null value.
+    target.aggregate(buf, 3, positions, rows, positionOffset);
+    for (int i = 0; i < positions.length; i++) {
+      Pair<Long, String> result = (Pair<Long, String>) target.get(buf, positions[i] + positionOffset);
+      Assert.assertEquals(times[rows[i]], result.lhs.longValue());
+      Assert.assertEquals(VALUES[rows[i]], result.rhs);
+    }
+  }
+
+  /**
+   * Initializes the aggregator state at {@code offset + position} for every given position.
+   */
+  private void clearBufferForPositions(int offset, int... positions)
+  {
+    for (int position : positions) {
+      target.init(buf, offset + position);
+    }
+  }
+}
diff --git
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 9feb822679..1af59214f8 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -683,8 +683,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testLatestAggregators() throws Exception
{
- // Cannot vectorize until StringLast is vectorized
- skipVectorize();
+
testQuery(
"SELECT "
+ "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), "
@@ -944,6 +943,39 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
+  /**
+   * LATEST on a string column under GROUP BY should plan to a groupBy query with a StringLastAggregatorFactory.
+   * Expected rows differ by null-handling mode: SQL-compatible mode has separate groups for null and "".
+   */
+  @Test
+  public void testStringLatestGroupBy() throws Exception
+  {
+    final String sql = "SELECT dim2, LATEST(dim4,10) AS val1 FROM druid.numfoo GROUP BY dim2";
+
+    final ImmutableList<Object[]> expectedRows;
+    if (NullHandling.sqlCompatible()) {
+      expectedRows = ImmutableList.of(
+          new Object[]{null, "b"},
+          new Object[]{"", "a"},
+          new Object[]{"a", "b"},
+          new Object[]{"abc", "b"}
+      );
+    } else {
+      expectedRows = ImmutableList.of(
+          new Object[]{"", "b"},
+          new Object[]{"a", "b"},
+          new Object[]{"abc", "b"}
+      );
+    }
+
+    testQuery(
+        sql,
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE3)
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "_d0")))
+                        .setAggregatorSpecs(aggregators(new StringLastAggregatorFactory("a0", "dim4", null, 10)))
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        expectedRows
+    );
+  }
+
// This test the off-heap (buffer) version of the EarliestAggregator
(Double/Float/Long)
@Test
public void testPrimitiveEarliestInSubquery() throws Exception
@@ -999,9 +1031,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testStringLatestInSubquery() throws Exception
{
- // Cannot vectorize LATEST aggregator for Strings
- skipVectorize();
-
testQuery(
"SELECT SUM(val) FROM (SELECT dim2, LATEST(dim1, 10) AS val FROM foo
GROUP BY dim2)",
ImmutableList.of(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]