This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 1fe22b5  Implement ARRAYLENGTH UDF for multi-valued columns (#5301)
1fe22b5 is described below

commit 1fe22b5b61f2fbab0c537602fa7103ea896f832f
Author: Bo Zhang <[email protected]>
AuthorDate: Mon Apr 27 13:46:33 2020 +0800

    Implement ARRAYLENGTH UDF for multi-valued columns (#5301)
    
    * Implement LENGTH UDF for multi-valued columns
    
    * Add integration test for LENGTH UDF
    
    * Fix a typo
    
    * Rename Length UDF to ArrayLength
---
 docs/pql_examples.rst                              |   3 +
 .../common/function/TransformFunctionType.java     |   1 +
 .../function/ArrayLengthTransformFunction.java     | 112 +++++++++++++++++++++
 .../transform/function/BaseTransformFunction.java  |   2 +
 .../function/TransformFunctionFactory.java         |   1 +
 .../function/ArrayLengthTransformFunctionTest.java |  61 +++++++++++
 .../tests/OfflineClusterIntegrationTest.java       |  22 ++++
 7 files changed, 202 insertions(+)

diff --git a/docs/pql_examples.rst b/docs/pql_examples.rst
index 2558c58..9362944 100644
--- a/docs/pql_examples.rst
+++ b/docs/pql_examples.rst
@@ -265,6 +265,9 @@ Supported transform functions
    expressed as hours since UTC epoch (note that the output is not Los Angeles
    timezone)
 
+``ARRAYLENGTH``
+   Takes a multi-valued column and returns the number of values in the column for each row
+
 ``VALUEIN``
    Takes at least 2 arguments, where the first argument is a multi-valued 
column, and the following arguments are constant values.
    The transform function will filter the value from the multi-valued column 
with the given constant values.
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 1c0cf15..cd795ec 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -42,6 +42,7 @@ public enum TransformFunctionType {
   TIMECONVERT("timeConvert"),
   DATETIMECONVERT("dateTimeConvert"),
   DATETRUNC("dateTrunc"),
+  ARRAYLENGTH("arrayLength"),
   VALUEIN("valueIn"),
   MAPVALUE("mapValue");
 
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunction.java
new file mode 100644
index 0000000..6d75d60
--- /dev/null
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunction.java
@@ -0,0 +1,112 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.core.common.DataSource;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+
+
+/**
+ * The ArrayLengthTransformFunction class implements the arrayLength function for multi-valued columns.
+ * For each document it produces the number of values of the multi-valued argument as a single-valued INT.
+ *
+ * Sample queries:
+ * SELECT COUNT(*) FROM table WHERE arrayLength(mvColumn) > 2
+ * SELECT COUNT(*) FROM table GROUP BY arrayLength(mvColumn)
+ * SELECT MAX(arrayLength(mvColumn)) FROM table
+ */
+public class ArrayLengthTransformFunction extends BaseTransformFunction {
+  public static final String FUNCTION_NAME = "arrayLength";
+
+  // Result buffer: allocated on the first transform call and reused by every later call
+  private int[] _results;
+  // The single multi-valued argument, assigned in init()
+  private TransformFunction _argument;
+
+  @Override
+  public String getName() {
+    return FUNCTION_NAME;
+  }
+
+  /**
+   * Validates the arguments and stores the multi-valued argument.
+   *
+   * @param arguments Arguments of the function; must hold exactly one non-literal, multi-valued entry
+   * @param dataSourceMap Data source map (not read by this function)
+   * @throws IllegalArgumentException if the number of arguments is not 1, or if the argument is a literal
+   *         or is single-valued
+   */
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
+    // Check that there is only 1 argument
+    if (arguments.size() != 1) {
+      throw new IllegalArgumentException("Exactly 1 argument is required for ARRAYLENGTH transform function");
+    }
+
+    // Check that the argument is a multi-valued column or transform function
+    TransformFunction firstArgument = arguments.get(0);
+    if (firstArgument instanceof LiteralTransformFunction || firstArgument.getResultMetadata().isSingleValue()) {
+      throw new IllegalArgumentException(
+          "The argument of ARRAYLENGTH transform function must be a multi-valued column or a transform function");
+    }
+    _argument = firstArgument;
+  }
+
+  /**
+   * The result is always a single-valued INT without dictionary, regardless of the argument's data type.
+   */
+  @Override
+  public TransformResultMetadata getResultMetadata() {
+    return INT_SV_NO_DICTIONARY_METADATA;
+  }
+
+  /**
+   * Computes the number of values of the multi-valued argument for each document in the block.
+   * The returned array is the internal buffer: this call writes only the first
+   * {@code projectionBlock.getNumDocs()} entries, and the same array is returned again by the next call.
+   *
+   * @param projectionBlock Block of documents to transform
+   * @return Array holding the per-document value counts
+   * @throws IllegalStateException if the data type of the argument is not INT, LONG, FLOAT, DOUBLE or STRING
+   */
+  @Override
+  public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
+    if (_results == null) {
+      _results = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+
+    int numDocs = projectionBlock.getNumDocs();
+    // Read the values with the accessor matching the argument's own data type; only the lengths are used
+    switch (_argument.getResultMetadata().getDataType()) {
+      case INT:
+        int[][] intValuesMV = _argument.transformToIntValuesMV(projectionBlock);
+        for (int i = 0; i < numDocs; i++) {
+          _results[i] = intValuesMV[i].length;
+        }
+        break;
+      case LONG:
+        long[][] longValuesMV = _argument.transformToLongValuesMV(projectionBlock);
+        for (int i = 0; i < numDocs; i++) {
+          _results[i] = longValuesMV[i].length;
+        }
+        break;
+      case FLOAT:
+        float[][] floatValuesMV = _argument.transformToFloatValuesMV(projectionBlock);
+        for (int i = 0; i < numDocs; i++) {
+          _results[i] = floatValuesMV[i].length;
+        }
+        break;
+      case DOUBLE:
+        double[][] doubleValuesMV = _argument.transformToDoubleValuesMV(projectionBlock);
+        for (int i = 0; i < numDocs; i++) {
+          _results[i] = doubleValuesMV[i].length;
+        }
+        break;
+      case STRING:
+        String[][] stringValuesMV = _argument.transformToStringValuesMV(projectionBlock);
+        for (int i = 0; i < numDocs; i++) {
+          _results[i] = stringValuesMV[i].length;
+        }
+        break;
+      default:
+        // Name the offending type so an unsupported argument can be diagnosed from the exception alone
+        throw new IllegalStateException(
+            "Unsupported data type for ARRAYLENGTH transform function: " + _argument.getResultMetadata()
+                .getDataType());
+    }
+    return _results;
+  }
+}
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunction.java
index f9478d9..6e89347 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunction.java
@@ -31,6 +31,8 @@ import org.apache.pinot.core.util.ArrayCopyUtils;
  * Base class for transform function providing the default implementation for 
all data types.
  */
 public abstract class BaseTransformFunction implements TransformFunction {
+  protected static final TransformResultMetadata INT_SV_NO_DICTIONARY_METADATA 
=
+      new TransformResultMetadata(DataType.INT, true, false);
   protected static final TransformResultMetadata 
LONG_SV_NO_DICTIONARY_METADATA =
       new TransformResultMetadata(DataType.LONG, true, false);
   protected static final TransformResultMetadata 
DOUBLE_SV_NO_DICTIONARY_METADATA =
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 41faf89..0d940e9 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -68,6 +68,7 @@ public class TransformFunctionFactory {
           put(TransformFunctionType.TIMECONVERT.getName().toLowerCase(), 
TimeConversionTransformFunction.class);
           put(TransformFunctionType.DATETIMECONVERT.getName().toLowerCase(), 
DateTimeConversionTransformFunction.class);
           put(TransformFunctionType.DATETRUNC.getName().toLowerCase(), 
DateTruncTransformFunction.class);
+          put(TransformFunctionType.ARRAYLENGTH.getName().toLowerCase(), 
ArrayLengthTransformFunction.class);
           put(TransformFunctionType.VALUEIN.getName().toLowerCase(), 
ValueInTransformFunction.class);
           put(TransformFunctionType.MAPVALUE.getName().toLowerCase(), 
MapValueTransformFunction.class);
         }
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunctionTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunctionTest.java
new file mode 100644
index 0000000..c688d3e
--- /dev/null
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunctionTest.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
+import org.apache.pinot.core.query.exception.BadQueryRequestException;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+
+/**
+ * Tests for the arrayLength transform function: result metadata, per-row lengths and argument validation.
+ */
+public class ArrayLengthTransformFunctionTest extends BaseTransformFunctionTest {
+
+  @Test
+  public void testLengthTransformFunction() {
+    // Build the function through the factory, the same way a query expression would be resolved
+    String expressionStr = String.format("arrayLength(%s)", INT_MV_COLUMN);
+    TransformExpressionTree expressionTree = TransformExpressionTree.compileToExpressionTree(expressionStr);
+    TransformFunction arrayLengthFunction = TransformFunctionFactory.get(expressionTree, _dataSourceMap);
+
+    // The factory must resolve to the arrayLength implementation with single-valued INT, no-dictionary metadata
+    Assert.assertTrue(arrayLengthFunction instanceof ArrayLengthTransformFunction);
+    Assert.assertEquals(arrayLengthFunction.getName(), ArrayLengthTransformFunction.FUNCTION_NAME);
+    Assert.assertEquals(arrayLengthFunction.getResultMetadata().getDataType(), FieldSpec.DataType.INT);
+    Assert.assertTrue(arrayLengthFunction.getResultMetadata().isSingleValue());
+    Assert.assertFalse(arrayLengthFunction.getResultMetadata().hasDictionary());
+
+    // Every row's result must equal the number of values generated for that row
+    int[] lengths = arrayLengthFunction.transformToIntValuesSV(_projectionBlock);
+    for (int row = 0; row < NUM_ROWS; row++) {
+      Assert.assertEquals(lengths[row], _intMVValues[row].length);
+    }
+  }
+
+  @Test(dataProvider = "testIllegalArguments", expectedExceptions = {BadQueryRequestException.class})
+  public void testIllegalArguments(String expressionStr) {
+    // Creating the function from an invalid expression is expected to throw
+    TransformExpressionTree expressionTree = TransformExpressionTree.compileToExpressionTree(expressionStr);
+    TransformFunctionFactory.get(expressionTree, _dataSourceMap);
+  }
+
+  @DataProvider(name = "testIllegalArguments")
+  public Object[][] testIllegalArguments() {
+    return new Object[][]{
+        // More than one argument
+        {String.format("arrayLength(%s,1)", INT_MV_COLUMN)},
+        // Literal argument
+        {"arrayLength(2)"},
+        // Single-valued column argument
+        {String.format("arrayLength(%s)", LONG_SV_COLUMN)}
+    };
+  }
+}
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index a98c8c3..5ccc832 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -566,6 +566,28 @@ public class OfflineClusterIntegrationTest extends 
BaseClusterIntegrationTestSet
     assertEquals(groupByEntry.get("group").get(0).asDouble(), 16138.0 / 2);
     assertEquals(groupByResult.get("groupByColumns").get(0).asText(), 
"div(DaysSinceEpoch,'2')");
 
+    pqlQuery = "SELECT COUNT(*) FROM mytable GROUP BY 
arrayLength(DivAirports)";
+    response = postQuery(pqlQuery);
+    groupByResult = response.get("aggregationResults").get(0);
+    groupByEntry = groupByResult.get("groupByResult").get(0);
+    assertEquals(groupByEntry.get("value").asDouble(), 115545.0);
+    assertEquals(groupByEntry.get("group").get(0).asText(), "5");
+    assertEquals(groupByResult.get("groupByColumns").get(0).asText(), 
"arraylength(DivAirports)");
+
+    pqlQuery = "SELECT COUNT(*) FROM mytable GROUP BY 
arrayLength(valueIn(DivAirports,'DFW','ORD'))";
+    response = postQuery(pqlQuery);
+    groupByResult = response.get("aggregationResults").get(0);
+    groupByEntry = groupByResult.get("groupByResult").get(0);
+    assertEquals(groupByEntry.get("value").asDouble(), 114895.0);
+    assertEquals(groupByEntry.get("group").get(0).asText(), "0");
+    groupByEntry = groupByResult.get("groupByResult").get(1);
+    assertEquals(groupByEntry.get("value").asDouble(), 648.0);
+    assertEquals(groupByEntry.get("group").get(0).asText(), "1");
+    groupByEntry = groupByResult.get("groupByResult").get(2);
+    assertEquals(groupByEntry.get("value").asDouble(), 2.0);
+    assertEquals(groupByEntry.get("group").get(0).asText(), "2");
+    assertEquals(groupByResult.get("groupByColumns").get(0).asText(), 
"arraylength(valuein(DivAirports,'DFW','ORD'))");
+
     pqlQuery = "SELECT COUNT(*) FROM mytable GROUP BY 
valueIn(DivAirports,'DFW','ORD')";
     response = postQuery(pqlQuery);
     groupByResult = response.get("aggregationResults").get(0);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to