This is an automated email from the ASF dual-hosted git repository.

jonwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new c36f12f1d8 Support complex variance object inputs for variance SQL agg function (#14463)
c36f12f1d8 is described below

commit c36f12f1d8bf81cf110dd41853627be175816b00
Author: Jonathan Wei <[email protected]>
AuthorDate: Wed Jun 28 13:14:19 2023 -0500

    Support complex variance object inputs for variance SQL agg function (#14463)
    
    * Support complex variance object inputs for variance SQL agg function
    
    * Add test
    
    * Include complexTypeChecker, address PR comments
    
    * Checkstyle, javadoc link
---
 .../variance/VarianceAggregatorFactory.java        |  2 +-
 .../variance/sql/BaseVarianceSqlAggregator.java    | 67 ++++++++++++++----
 .../variance/sql/VarianceSqlAggregatorTest.java    | 58 ++++++++++++++-
 .../druid/sql/calcite/table/RowSignatures.java     | 82 ++++++++++++++++++++--
 4 files changed, 188 insertions(+), 21 deletions(-)

diff --git 
a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
 
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
index 47eccfbffd..40d06bbbe0 100644
--- 
a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
+++ 
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
@@ -60,7 +60,7 @@ import java.util.Objects;
 @JsonTypeName("variance")
 public class VarianceAggregatorFactory extends AggregatorFactory
 {
-  private static final String VARIANCE_TYPE_NAME = "variance";
+  public static final String VARIANCE_TYPE_NAME = "variance";
   public static final ColumnType TYPE = 
ColumnType.ofComplex(VARIANCE_TYPE_NAME);
 
   protected final String fieldName;
diff --git 
a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
 
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
index 3eb3f49816..0b1562eb83 100644
--- 
a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
+++ 
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
@@ -26,7 +26,10 @@ import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.SqlFunctionCategory;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.query.aggregation.AggregatorFactory;
@@ -42,15 +45,33 @@ import 
org.apache.druid.sql.calcite.aggregation.Aggregations;
 import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
 import org.apache.druid.sql.calcite.expression.DruidExpression;
 import org.apache.druid.sql.calcite.expression.Expressions;
+import org.apache.druid.sql.calcite.expression.OperatorConversions;
 import org.apache.druid.sql.calcite.planner.Calcites;
 import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
+import org.apache.druid.sql.calcite.table.RowSignatures;
 
 import javax.annotation.Nullable;
 import java.util.List;
 
 public abstract class BaseVarianceSqlAggregator implements SqlAggregator
 {
+  private static final String VARIANCE_NAME = "VARIANCE";
+  private static final String STDDEV_NAME = "STDDEV";
+
+  private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
+      buildSqlAvgAggFunction(VARIANCE_NAME);
+  private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
+      buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
+  private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
+      buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
+  private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
+      buildSqlAvgAggFunction(STDDEV_NAME);
+  private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
+      buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
+  private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
+      buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());
+
   @Nullable
   @Override
   public Aggregation toDruidAggregation(
@@ -104,12 +125,13 @@ public abstract class BaseVarianceSqlAggregator 
implements SqlAggregator
 
     if (inputType.isNumeric()) {
       inputTypeName = StringUtils.toLowerCase(inputType.getType().name());
+    } else if (inputType.equals(VarianceAggregatorFactory.TYPE)) {
+      inputTypeName = VarianceAggregatorFactory.VARIANCE_TYPE_NAME;
     } else {
       throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", 
func, inputType.asTypeString());
     }
 
-
-    if (func == SqlStdOperatorTable.VAR_POP || func == 
SqlStdOperatorTable.STDDEV_POP) {
+    if (func.getName().equals(SqlKind.VAR_POP.name()) || 
func.getName().equals(SqlKind.STDDEV_POP.name())) {
       estimator = "population";
     } else {
       estimator = "sample";
@@ -122,9 +144,9 @@ public abstract class BaseVarianceSqlAggregator implements 
SqlAggregator
         inputTypeName
     );
 
-    if (func == SqlStdOperatorTable.STDDEV_POP
-        || func == SqlStdOperatorTable.STDDEV_SAMP
-        || func == SqlStdOperatorTable.STDDEV) {
+    if (func.getName().equals(STDDEV_NAME)
+        || func.getName().equals(SqlKind.STDDEV_POP.name())
+        || func.getName().equals(SqlKind.STDDEV_SAMP.name())) {
       postAggregator = new StandardDeviationPostAggregator(
           name,
           aggregatorFactory.getName(),
@@ -137,21 +159,40 @@ public abstract class BaseVarianceSqlAggregator 
implements SqlAggregator
     );
   }
 
+  /**
+   * Creates a {@link SqlAggFunction} that is the same as {@link 
org.apache.calcite.sql.fun.SqlAvgAggFunction}
+   * but with an operand type that accepts variance aggregator objects in 
addition to numeric inputs.
+   */
+  private static SqlAggFunction buildSqlAvgAggFunction(String name)
+  {
+    return OperatorConversions
+        .aggregatorBuilder(name)
+        .returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
+        .operandTypeChecker(
+            OperandTypes.or(
+                OperandTypes.NUMERIC,
+                
RowSignatures.complexTypeChecker(VarianceAggregatorFactory.TYPE)
+            )
+        )
+        .functionCategory(SqlFunctionCategory.NUMERIC)
+        .build();
+  }
+
   public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
   {
     @Override
     public SqlAggFunction calciteFunction()
     {
-      return SqlStdOperatorTable.VAR_POP;
+      return VARIANCE_POP_SQL_AGG_FUNC_INSTANCE;
     }
   }
-  
+
   public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
   {
     @Override
     public SqlAggFunction calciteFunction()
     {
-      return SqlStdOperatorTable.VAR_SAMP;
+      return VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE;
     }
   }
 
@@ -160,7 +201,7 @@ public abstract class BaseVarianceSqlAggregator implements 
SqlAggregator
     @Override
     public SqlAggFunction calciteFunction()
     {
-      return SqlStdOperatorTable.VARIANCE;
+      return VARIANCE_SQL_AGG_FUNC_INSTANCE;
     }
   }
 
@@ -169,7 +210,7 @@ public abstract class BaseVarianceSqlAggregator implements 
SqlAggregator
     @Override
     public SqlAggFunction calciteFunction()
     {
-      return SqlStdOperatorTable.STDDEV_POP;
+      return STDDEV_POP_SQL_AGG_FUNC_INSTANCE;
     }
   }
 
@@ -178,7 +219,7 @@ public abstract class BaseVarianceSqlAggregator implements 
SqlAggregator
     @Override
     public SqlAggFunction calciteFunction()
     {
-      return SqlStdOperatorTable.STDDEV_SAMP;
+      return STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE;
     }
   }
 
@@ -187,7 +228,7 @@ public abstract class BaseVarianceSqlAggregator implements 
SqlAggregator
     @Override
     public SqlAggFunction calciteFunction()
     {
-      return SqlStdOperatorTable.STDDEV;
+      return STDDEV_SQL_AGG_FUNC_INSTANCE;
     }
   }
 }
diff --git 
a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
 
b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
index bc1ef68169..5c496c4663 100644
--- 
a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
+++ 
b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
@@ -40,6 +40,7 @@ import 
org.apache.druid.query.aggregation.stats.DruidStatsModule;
 import 
org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
 import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
 import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
+import org.apache.druid.query.aggregation.variance.VarianceSerde;
 import org.apache.druid.query.dimension.DefaultDimensionSpec;
 import org.apache.druid.query.groupby.GroupByQuery;
 import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
@@ -51,6 +52,7 @@ import org.apache.druid.segment.QueryableIndex;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.incremental.IncrementalIndexSchema;
 import org.apache.druid.segment.join.JoinableFactoryWrapper;
+import org.apache.druid.segment.serde.ComplexMetrics;
 import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
 import 
org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
 import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
@@ -82,8 +84,10 @@ public class VarianceSqlAggregatorTest extends 
BaseCalciteQueryTest
       final Injector injector
   ) throws IOException
   {
+    ComplexMetrics.registerSerde(VarianceSerde.TYPE_NAME, new VarianceSerde());
+
     final QueryableIndex index =
-        IndexBuilder.create()
+        IndexBuilder.create(CalciteTests.getJsonMapper().registerModules(new 
DruidStatsModule().getJacksonModules()))
                     .tmpDir(temporaryFolder.newFolder())
                     
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
                     .schema(
@@ -100,7 +104,8 @@ public class VarianceSqlAggregatorTest extends 
BaseCalciteQueryTest
                             )
                             .withMetrics(
                                 new CountAggregatorFactory("cnt"),
-                                new DoubleSumAggregatorFactory("m1", "m1")
+                                new DoubleSumAggregatorFactory("m1", "m1"),
+                                new VarianceAggregatorFactory("var1", "m1", 
null, null)
                             )
                             .withRollup(false)
                             .build()
@@ -624,6 +629,55 @@ public class VarianceSqlAggregatorTest extends 
BaseCalciteQueryTest
     );
   }
 
+  @Test
+  public void testVarianceAggAsInput()
+  {
+    final List<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{
+            "3.5",
+            "2.9166666666666665",
+            "3.5",
+            "1.8708286933869707",
+            "1.707825127659933",
+            "1.8708286933869707"
+        }
+    );
+    testQuery(
+        "SELECT\n"
+        + "VARIANCE(var1),\n"
+        + "VAR_POP(var1),\n"
+        + "VAR_SAMP(var1),\n"
+        + "STDDEV(var1),\n"
+        + "STDDEV_POP(var1),\n"
+        + "STDDEV_SAMP(var1)\n"
+        + "FROM numfoo",
+        ImmutableList.of(
+            Druids.newTimeseriesQueryBuilder()
+                  .dataSource(CalciteTests.DATASOURCE3)
+                  .intervals(new 
MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+                  .granularity(Granularities.ALL)
+                  .aggregators(
+                      ImmutableList.of(
+                          new VarianceAggregatorFactory("a0:agg", "var1", 
"sample", "variance"),
+                          new VarianceAggregatorFactory("a1:agg", "var1", 
"population", "variance"),
+                          new VarianceAggregatorFactory("a2:agg", "var1", 
"sample", "variance"),
+                          new VarianceAggregatorFactory("a3:agg", "var1", 
"sample", "variance"),
+                          new VarianceAggregatorFactory("a4:agg", "var1", 
"population", "variance"),
+                          new VarianceAggregatorFactory("a5:agg", "var1", 
"sample", "variance")
+                      )
+                  )
+                  .postAggregators(
+                      new StandardDeviationPostAggregator("a3", "a3:agg", 
"sample"),
+                      new StandardDeviationPostAggregator("a4", "a4:agg", 
"population"),
+                      new StandardDeviationPostAggregator("a5", "a5:agg", 
"sample")
+                  )
+                  .context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
+                  .build()
+        ),
+        expectedResults
+    );
+  }
+
   @Override
   public void assertResultsEquals(String sql, List<Object[]> expectedResults, 
List<Object[]> results)
   {
diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java 
b/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java
index 32abe56ee8..87519c7537 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java
@@ -23,11 +23,18 @@ import com.google.common.base.Preconditions;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeComparability;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.sql.SqlCallBinding;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperandCountRange;
+import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.type.AbstractSqlType;
+import org.apache.calcite.sql.type.SqlOperandCountRanges;
+import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.query.ordering.StringComparator;
 import org.apache.druid.query.ordering.StringComparators;
 import org.apache.druid.segment.column.ColumnHolder;
@@ -79,7 +86,9 @@ public class RowSignatures
   {
     Preconditions.checkNotNull(simpleExtraction, "simpleExtraction");
     if (simpleExtraction.getExtractionFn() != null
-        || rowSignature.getColumnType(simpleExtraction.getColumn()).map(type 
-> type.is(ValueType.STRING)).orElse(false)) {
+        || rowSignature.getColumnType(simpleExtraction.getColumn())
+                       .map(type -> type.is(ValueType.STRING))
+                       .orElse(false)) {
       return StringComparators.LEXICOGRAPHIC;
     } else {
       return StringComparators.NUMERIC;
@@ -164,7 +173,7 @@ public class RowSignatures
    * Creates a {@link ComplexSqlType} using the supplied {@link 
RelDataTypeFactory} to ensure that the
    * {@link ComplexSqlType} is interned. This is important because Calcite 
checks that the references are equal
    * instead of the objects being equivalent.
-   *
+   * <p>
    * This method uses {@link 
RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean) ensures that 
if the
    * type factory is a {@link 
org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed 
through
    * {@link 
org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which 
interns the type.
@@ -179,15 +188,15 @@ public class RowSignatures
 
   /**
    * Calcite {@link RelDataType} for Druid complex columns, to preserve 
complex type information.
-   *
+   * <p>
    * If using with other operations of a {@link RelDataTypeFactory}, consider 
wrapping the creation of this type in
    * {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean) 
to ensure that if the type factory is a
    * {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type 
is passed through
    * {@link 
org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which 
interns the type.
-   *
+   * <p>
    * If {@link SqlTypeName} is going to be {@link SqlTypeName#OTHER} and a 
{@link RelDataTypeFactory} is available,
    * consider using {@link #makeComplexType(RelDataTypeFactory, ColumnType, 
boolean)}.
-   *
+   * <p>
    * This type does not work well with {@link 
org.apache.calcite.sql.type.ReturnTypes#explicit(RelDataType)}, which
    * will create new {@link RelDataType} using {@link SqlTypeName} during 
return type inference, so implementors of
    * {@link org.apache.druid.sql.calcite.expression.SqlOperatorConversion} 
should implement the
@@ -235,4 +244,67 @@ public class RowSignatures
       return columnType.asTypeString();
     }
   }
+
+  public static ComplexSqlSingleOperandTypeChecker 
complexTypeChecker(ColumnType complexType)
+  {
+    return new ComplexSqlSingleOperandTypeChecker(
+        new ComplexSqlType(SqlTypeName.OTHER, complexType, true)
+    );
+  }
+
+  public static final class ComplexSqlSingleOperandTypeChecker implements 
SqlSingleOperandTypeChecker
+  {
+    private final ComplexSqlType type;
+
+    public ComplexSqlSingleOperandTypeChecker(
+        ComplexSqlType type
+    )
+    {
+      this.type = type;
+    }
+
+    @Override
+    public boolean checkSingleOperandType(
+        SqlCallBinding callBinding,
+        SqlNode operand,
+        int iFormalOperand,
+        boolean throwOnFailure
+    )
+    {
+      return 
type.equals(callBinding.getValidator().deriveType(callBinding.getScope(), 
operand));
+    }
+
+    @Override
+    public boolean checkOperandTypes(SqlCallBinding callBinding, boolean 
throwOnFailure)
+    {
+      if (callBinding.getOperandCount() != 1) {
+        return false;
+      }
+      return checkSingleOperandType(callBinding, callBinding.operand(0), 0, 
throwOnFailure);
+    }
+
+    @Override
+    public SqlOperandCountRange getOperandCountRange()
+    {
+      return SqlOperandCountRanges.of(1);
+    }
+
+    @Override
+    public String getAllowedSignatures(SqlOperator op, String opName)
+    {
+      return StringUtils.format("'%s'(%s)", opName, type);
+    }
+
+    @Override
+    public Consistency getConsistency()
+    {
+      return Consistency.NONE;
+    }
+
+    @Override
+    public boolean isOptional(int i)
+    {
+      return false;
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to