This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch fixing_output_schema_for_aggregation_groupbys
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git

commit 3a3d70ca78aa700a7af5ed0a177dbd87c03011ed
Author: Xiang Fu <[email protected]>
AuthorDate: Mon Jan 27 05:03:23 2020 -0800

    Make output schema to match selection list for aggregation groupbys
---
 .../pinot/common/utils/request/RequestUtils.java   |  29 ++++++
 .../parsers/PinotQuery2BrokerRequestConverter.java |   3 +-
 .../apache/pinot/sql/parsers/CalciteSqlParser.java |  27 +++--
 .../pinot/sql/parsers/CalciteSqlCompilerTest.java  |  13 ++-
 .../core/query/reduce/GroupByDataTableReducer.java | 112 ++++++++++++++++++++-
 .../tests/BaseClusterIntegrationTestSet.java       |   4 +
 6 files changed, 172 insertions(+), 16 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
index ba16f94..0a3a71d 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
@@ -321,4 +321,33 @@ public class RequestUtils {
       throw new IllegalStateException("Cannot get expression from " + 
astNode.getClass().getSimpleName());
     }
   }
+
+  /**
+   * Renders an expression as text; a function call prints as {@code operator(arg1, arg2)}.
+   */
+  public static String prettyPrint(Expression expression) {
+    if (expression == null) {
+      return "null";
+    }
+    if (expression.getIdentifier() != null) {
+      return expression.getIdentifier().getName();
+    }
+    // Only long literals are rendered here; other literal types fall through to the checks below.
+    if (expression.getLiteral() != null && expression.getLiteral().isSetLongValue()) {
+      return Long.toString(expression.getLiteral().getLongValue());
+    }
+    if (expression.getFunctionCall() != null) {
+      StringBuilder res = new StringBuilder();
+      res.append(expression.getFunctionCall().getOperator()).append('(');
+      String separator = "";
+      for (Expression operand : expression.getFunctionCall().getOperands()) {
+        // Emit the separator before every operand except the first: "f(a, b)", not "f(ab, )".
+        res.append(separator).append(prettyPrint(operand));
+        separator = ", ";
+      }
+      return res.append(')').toString();
+    }
+    // Expression kinds not handled above (e.g. non-long literals) have no rendering.
+    return null;
+  }
 }
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
 
b/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
index cf4f6a3..9adcb1f 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
@@ -71,10 +71,11 @@ public class PinotQuery2BrokerRequestConverter {
 
     //TODO: these should not be part of the query?
     //brokerRequest.setEnableTrace();
-    //brokerRequest.setDebugOptions();
+    brokerRequest.setDebugOptions(pinotQuery.getDebugOptions());
     brokerRequest.setQueryOptions(pinotQuery.getQueryOptions());
     //brokerRequest.setBucketHashKey();
     //brokerRequest.setDuration();
+    brokerRequest.setPinotQuery(pinotQuery);
 
     return brokerRequest;
   }
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java 
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
index 84d7777..cec5752 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
@@ -101,24 +101,29 @@ public class CalciteSqlParser {
 
   private static void validateGroupByClause(PinotQuery pinotQuery)
       throws SqlCompilationException {
-    if(pinotQuery.getGroupByList() == null) {
+    if (pinotQuery.getGroupByList() == null) {
       return;
     }
-    // Sanity check group by query: All identifiers in selection list should 
be also included in group by list.
-    Set<String> groupByIdentifiers = 
extractIdentifiers(pinotQuery.getGroupByList());
+    // Sanity check group by query: All non-aggregate expression in selection 
list should be also included in group by list.
     for (Expression selectExpression : pinotQuery.getSelectList()) {
-      if (selectExpression.getIdentifier() != null) {
-        String identifier = selectExpression.getIdentifier().getName();
-        if (!groupByIdentifiers.contains(identifier)) {
-          throw new SqlCompilationException("'" + identifier + "' should 
appear in GROUP BY clause.");
+      if (!isAggregateExpression(selectExpression)) {
+        boolean foundInGroupByClause = false;
+        for (Expression groupByExpression : pinotQuery.getGroupByList()) {
+          if (groupByExpression.equals(selectExpression)) {
+            foundInGroupByClause = true;
+          }
+        }
+        if (!foundInGroupByClause) {
+          throw new SqlCompilationException(
+              "'" + RequestUtils.prettyPrint(selectExpression) + "' should 
appear in GROUP BY clause.");
         }
       }
     }
     // Sanity check on group by clause shouldn't contain aggregate expression.
-    for (Expression selectExpression : pinotQuery.getGroupByList()) {
-      if (isAggregateExpression(selectExpression)) {
-        throw new SqlCompilationException(
-            "Aggregate expression '" + selectExpression + "' is not allowed in 
GROUP BY clause.");
+    for (Expression groupByExpression : pinotQuery.getGroupByList()) {
+      if (isAggregateExpression(groupByExpression)) {
+        throw new SqlCompilationException("Aggregate expression '" + 
RequestUtils.prettyPrint(groupByExpression)
+            + "' is not allowed in GROUP BY clause.");
       }
     }
   }
diff --git 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 383863d..6ae54e3 100644
--- 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++ 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -871,11 +871,22 @@ public class CalciteSqlCompilerTest {
 
     // Valid groupBy non-aggregate function should pass.
     sql =
-        "select secondsSinceEpoch, sum(rsvp_count), count(*) from meetupRsvp 
group by dateConvert(secondsSinceEpoch) limit 50";
+        "select dateConvert(secondsSinceEpoch), sum(rsvp_count), count(*) from 
meetupRsvp group by dateConvert(secondsSinceEpoch) limit 50";
     pinotQuery = CalciteSqlParser.compileToPinotQuery(sql);
     Assert.assertEquals(pinotQuery.getGroupByListSize(), 1);
     Assert.assertEquals(pinotQuery.getSelectListSize(), 3);
 
+    // Invalid: secondsSinceEpoch should be in groupBy clause.
+    try {
+      sql =
+          "select secondsSinceEpoch, dateConvert(secondsSinceEpoch), 
sum(rsvp_count), count(*) from meetupRsvp group by 
dateConvert(secondsSinceEpoch) limit 50";
+      CalciteSqlParser.compileToPinotQuery(sql);
+      Assert.fail("Query should have failed compilation");
+    } catch (Exception e) {
+      Assert.assertTrue(e instanceof SqlCompilationException);
+      Assert.assertTrue(e.getMessage().contains("'secondsSinceEpoch' should 
appear in GROUP BY clause."));
+    }
+
     // Invalid groupBy clause shouldn't contain aggregate expression, like 
sum(rsvp_count), count(*).
     try {
       sql =
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
index cc6fd73..9131eae 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
@@ -32,6 +32,7 @@ import org.apache.pinot.common.metrics.BrokerMeter;
 import org.apache.pinot.common.metrics.BrokerMetrics;
 import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.BrokerRequest;
+import org.apache.pinot.common.request.Expression;
 import org.apache.pinot.common.request.GroupBy;
 import org.apache.pinot.common.request.HavingFilterQuery;
 import org.apache.pinot.common.request.HavingFilterQueryMap;
@@ -42,6 +43,7 @@ import org.apache.pinot.common.response.broker.GroupByResult;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataTable;
+import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
 import org.apache.pinot.core.data.table.IndexedTable;
 import org.apache.pinot.core.data.table.Record;
@@ -53,12 +55,15 @@ import 
org.apache.pinot.core.transport.ServerRoutingInstance;
 import org.apache.pinot.core.util.GroupByUtils;
 import org.apache.pinot.core.util.QueryOptions;
 import org.apache.pinot.spi.utils.BytesUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
  * Helper class to reduce data tables and set group by results into the 
BrokerResponseNative
  */
 public class GroupByDataTableReducer implements DataTableReducer {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(GroupByDataTableReducer.class);
 
   private final BrokerRequest _brokerRequest;
   private final AggregationFunction[] _aggregationFunctions;
@@ -72,6 +77,8 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
   private final boolean _preserveType;
   private final boolean _groupByModeSql;
   private final boolean _responseFormatSql;
+  private final List<Expression> _sqlSelectionList;
+  private final List<Expression> _groupByList;
 
   GroupByDataTableReducer(BrokerRequest brokerRequest, AggregationFunction[] 
aggregationFunctions,
       QueryOptions queryOptions) {
@@ -87,6 +94,13 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     _preserveType = queryOptions.isPreserveType();
     _groupByModeSql = queryOptions.isGroupByModeSQL();
     _responseFormatSql = queryOptions.isResponseFormatSQL();
+    if (_responseFormatSql && brokerRequest.getPinotQuery() != null) {
+      _sqlSelectionList = brokerRequest.getPinotQuery().getSelectList();
+      _groupByList = brokerRequest.getPinotQuery().getGroupByList();
+    } else {
+      _sqlSelectionList = null;
+      _groupByList = null;
+    }
   }
 
   /**
@@ -161,7 +175,6 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
           resultSize = 
brokerResponseNative.getAggregationResults().get(0).getGroupByResult().size();
         }
       }
-
     }
 
     if (brokerMetrics != null && resultSize > 0) {
@@ -180,11 +193,14 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
 
     IndexedTable indexedTable = getIndexedTable(dataSchema, dataTables);
 
+    int[] finalSchemaMapIdx = null;
+    if (_sqlSelectionList != null) {
+      finalSchemaMapIdx = getFinalSchemaMapIdx(dataSchema);
+    }
     List<Object[]> rows = new ArrayList<>();
     Iterator<Record> sortedIterator = indexedTable.iterator();
     int numRows = 0;
     while (numRows < _groupBy.getTopN() && sortedIterator.hasNext()) {
-
       Record nextRecord = sortedIterator.next();
       Object[] values = nextRecord.getValues();
 
@@ -194,15 +210,105 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
         values[index] = 
_aggregationFunctions[aggNum++].extractFinalResult(values[index]);
         index++;
       }
-      rows.add(values);
+      if (_sqlSelectionList != null) {
+        Object[] finalValues = new Object[_sqlSelectionList.size()];
+        for (int i = 0; i < finalSchemaMapIdx.length; i++) {
+          if (finalSchemaMapIdx[i] == -1) {
+            finalValues[i] = null;
+          } else {
+            finalValues[i] = values[finalSchemaMapIdx[i]];
+          }
+        }
+        rows.add(finalValues);
+      } else {
+        rows.add(values);
+      }
       numRows++;
     }
 
     DataSchema finalDataSchema = getSQLResultTableSchema(dataSchema);
+    if (_sqlSelectionList != null) {
+      int columnSize = _sqlSelectionList.size();
+      String[] columns = new String[columnSize];
+      DataSchema.ColumnDataType[] columnDataTypes = new 
DataSchema.ColumnDataType[columnSize];
+      for (int i = 0; i < columnSize; i++) {
+        if (finalSchemaMapIdx[i] == -1) {
+          columns[i] = RequestUtils.prettyPrint(_sqlSelectionList.get(i));
+          columnDataTypes[i] = DataSchema.ColumnDataType.STRING;
+        } else {
+          columns[i] = finalDataSchema.getColumnName(finalSchemaMapIdx[i]);
+          columnDataTypes[i] = 
finalDataSchema.getColumnDataType(finalSchemaMapIdx[i]);
+        }
+      }
+      finalDataSchema = new DataSchema(columns, columnDataTypes);
+    }
     brokerResponseNative.setResultTable(new ResultTable(finalDataSchema, 
rows));
   }
 
   /**
+   * Maps every selection expression to its column in the DataTable schema (groupBy columns
+   * first, then aggregation functions).
+   *
+   * @param dataSchema DataTable schema the selection expressions are resolved against
+   * @return array whose i-th entry is the DataTable schema index of the i-th selection expression
+   */
+  private int[] getFinalSchemaMapIdx(DataSchema dataSchema) {
+    int[] mapping = new int[_sqlSelectionList.size()];
+    int aggregationCursor = _numGroupBy;
+    for (int pos = 0; pos < mapping.length; pos++) {
+      int mappedIdx = getExpressionMapIdx(dataSchema, _sqlSelectionList.get(pos), aggregationCursor);
+      // Advance only when the expression consumed the current aggregation slot.
+      aggregationCursor += (mappedIdx == aggregationCursor) ? 1 : 0;
+      mapping[pos] = mappedIdx;
+    }
+    return mapping;
+  }
+
+  /** Maps one selection expression to its DataTable schema column index, or -1 if unmapped. */
+  private int getExpressionMapIdx(DataSchema dataSchema, Expression expression, int nextAggregationIdx) {
+    int groupByIdx = getGroupByIdx(_groupByList, expression);
+    if (groupByIdx != -1) {
+      return groupByIdx;
+    }
+    if (expression.getFunctionCall() != null) {
+      if (!expression.getFunctionCall().getOperator().equalsIgnoreCase("AS")) {
+        // Any function other than AS takes the next aggregation slot.
+        return nextAggregationIdx;
+      }
+      // AS: resolve the first operand instead.
+      return getExpressionMapIdx(dataSchema, expression.getFunctionCall().getOperands().get(0), nextAggregationIdx);
+    }
+    if (expression.getIdentifier() == null) {
+      return -1;
+    }
+    // Plain column: case-insensitive lookup by name in the schema.
+    String columnName = expression.getIdentifier().getName();
+    for (int col = 0; col < dataSchema.size(); col++) {
+      if (columnName.equalsIgnoreCase(dataSchema.getColumnName(col))) {
+        return col;
+      }
+    }
+    return -1;
+  }
+
+  /**
+   * Looks up the position of an expression within the given groupBy list.
+   * @param groupByList groupBy expressions to search, in order
+   * @param expression expression to look for
+   * @return index of the first groupBy expression equal to {@code expression}, or -1 if none
+   */
+  private int getGroupByIdx(List<Expression> groupByList, Expression expression) {
+    int idx = 0;
+    for (Expression candidate : groupByList) {
+      if (candidate.equals(expression)) {
+        return idx;
+      }
+      idx++;
+    }
+    return -1;
+  }
+
+  /**
    * Constructs the final result table schema for sql mode execution
    * The data type for the aggregations needs to be taken from the final 
result's data type
    */
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
index 191d65a..3c5e917 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
@@ -172,6 +172,10 @@ public abstract class BaseClusterIntegrationTestSet 
extends BaseClusterIntegrati
     testSqlQuery(query, Collections.singletonList(query));
     query = "SELECT COUNT(*), MAX(ArrTime), MIN(ArrTime) FROM mytable WHERE 
DaysSinceEpoch >= 16312";
     testSqlQuery(query, Collections.singletonList(query));
+    query = "SELECT COUNT(*), MAX(ArrTime), MIN(ArrTime), DaysSinceEpoch FROM 
mytable GROUP BY DaysSinceEpoch";
+    testSqlQuery(query, Collections.singletonList(query));
+    query = "SELECT DaysSinceEpoch, COUNT(*), MAX(ArrTime), MIN(ArrTime) FROM 
mytable GROUP BY DaysSinceEpoch";
+    testSqlQuery(query, Collections.singletonList(query));
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to