This is an automated email from the ASF dual-hosted git repository.
jonwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new c36f12f1d8 Support complex variance object inputs for variance SQL agg
function (#14463)
c36f12f1d8 is described below
commit c36f12f1d8bf81cf110dd41853627be175816b00
Author: Jonathan Wei <[email protected]>
AuthorDate: Wed Jun 28 13:14:19 2023 -0500
Support complex variance object inputs for variance SQL agg function
(#14463)
* Support complex variance object inputs for variance SQL agg function
* Add test
* Include complexTypeChecker, address PR comments
* Checkstyle, javadoc link
---
.../variance/VarianceAggregatorFactory.java | 2 +-
.../variance/sql/BaseVarianceSqlAggregator.java | 67 ++++++++++++++----
.../variance/sql/VarianceSqlAggregatorTest.java | 58 ++++++++++++++-
.../druid/sql/calcite/table/RowSignatures.java | 82 ++++++++++++++++++++--
4 files changed, 188 insertions(+), 21 deletions(-)
diff --git
a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
index 47eccfbffd..40d06bbbe0 100644
---
a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
+++
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
@@ -60,7 +60,7 @@ import java.util.Objects;
@JsonTypeName("variance")
public class VarianceAggregatorFactory extends AggregatorFactory
{
- private static final String VARIANCE_TYPE_NAME = "variance";
+ public static final String VARIANCE_TYPE_NAME = "variance";
public static final ColumnType TYPE =
ColumnType.ofComplex(VARIANCE_TYPE_NAME);
protected final String fieldName;
diff --git
a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
index 3eb3f49816..0b1562eb83 100644
---
a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
+++
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
@@ -26,7 +26,10 @@ import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.SqlFunctionCategory;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
@@ -42,15 +45,33 @@ import
org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
+import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
+import org.apache.druid.sql.calcite.table.RowSignatures;
import javax.annotation.Nullable;
import java.util.List;
public abstract class BaseVarianceSqlAggregator implements SqlAggregator
{
+ private static final String VARIANCE_NAME = "VARIANCE";
+ private static final String STDDEV_NAME = "STDDEV";
+
+ private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
+ buildSqlAvgAggFunction(VARIANCE_NAME);
+ private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
+ buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
+ private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
+ buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
+ private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
+ buildSqlAvgAggFunction(STDDEV_NAME);
+ private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
+ buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
+ private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
+ buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());
+
@Nullable
@Override
public Aggregation toDruidAggregation(
@@ -104,12 +125,13 @@ public abstract class BaseVarianceSqlAggregator
implements SqlAggregator
if (inputType.isNumeric()) {
inputTypeName = StringUtils.toLowerCase(inputType.getType().name());
+ } else if (inputType.equals(VarianceAggregatorFactory.TYPE)) {
+ inputTypeName = VarianceAggregatorFactory.VARIANCE_TYPE_NAME;
} else {
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]",
func, inputType.asTypeString());
}
-
- if (func == SqlStdOperatorTable.VAR_POP || func ==
SqlStdOperatorTable.STDDEV_POP) {
+ if (func.getName().equals(SqlKind.VAR_POP.name()) ||
func.getName().equals(SqlKind.STDDEV_POP.name())) {
estimator = "population";
} else {
estimator = "sample";
@@ -122,9 +144,9 @@ public abstract class BaseVarianceSqlAggregator implements
SqlAggregator
inputTypeName
);
- if (func == SqlStdOperatorTable.STDDEV_POP
- || func == SqlStdOperatorTable.STDDEV_SAMP
- || func == SqlStdOperatorTable.STDDEV) {
+ if (func.getName().equals(STDDEV_NAME)
+ || func.getName().equals(SqlKind.STDDEV_POP.name())
+ || func.getName().equals(SqlKind.STDDEV_SAMP.name())) {
postAggregator = new StandardDeviationPostAggregator(
name,
aggregatorFactory.getName(),
@@ -137,21 +159,40 @@ public abstract class BaseVarianceSqlAggregator
implements SqlAggregator
);
}
+ /**
+ * Creates a {@link SqlAggFunction} that is the same as {@link
org.apache.calcite.sql.fun.SqlAvgAggFunction}
+ * but with an operand type that accepts variance aggregator objects in
addition to numeric inputs.
+ */
+ private static SqlAggFunction buildSqlAvgAggFunction(String name)
+ {
+ return OperatorConversions
+ .aggregatorBuilder(name)
+ .returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
+ .operandTypeChecker(
+ OperandTypes.or(
+ OperandTypes.NUMERIC,
+
RowSignatures.complexTypeChecker(VarianceAggregatorFactory.TYPE)
+ )
+ )
+ .functionCategory(SqlFunctionCategory.NUMERIC)
+ .build();
+ }
+
public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
- return SqlStdOperatorTable.VAR_POP;
+ return VARIANCE_POP_SQL_AGG_FUNC_INSTANCE;
}
}
-
+
public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
- return SqlStdOperatorTable.VAR_SAMP;
+ return VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE;
}
}
@@ -160,7 +201,7 @@ public abstract class BaseVarianceSqlAggregator implements
SqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
- return SqlStdOperatorTable.VARIANCE;
+ return VARIANCE_SQL_AGG_FUNC_INSTANCE;
}
}
@@ -169,7 +210,7 @@ public abstract class BaseVarianceSqlAggregator implements
SqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
- return SqlStdOperatorTable.STDDEV_POP;
+ return STDDEV_POP_SQL_AGG_FUNC_INSTANCE;
}
}
@@ -178,7 +219,7 @@ public abstract class BaseVarianceSqlAggregator implements
SqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
- return SqlStdOperatorTable.STDDEV_SAMP;
+ return STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE;
}
}
@@ -187,7 +228,7 @@ public abstract class BaseVarianceSqlAggregator implements
SqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
- return SqlStdOperatorTable.STDDEV;
+ return STDDEV_SQL_AGG_FUNC_INSTANCE;
}
}
}
diff --git
a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
index bc1ef68169..5c496c4663 100644
---
a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
+++
b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
@@ -40,6 +40,7 @@ import
org.apache.druid.query.aggregation.stats.DruidStatsModule;
import
org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
+import org.apache.druid.query.aggregation.variance.VarianceSerde;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
@@ -51,6 +52,7 @@ import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
+import org.apache.druid.segment.serde.ComplexMetrics;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import
org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
@@ -82,8 +84,10 @@ public class VarianceSqlAggregatorTest extends
BaseCalciteQueryTest
final Injector injector
) throws IOException
{
+ ComplexMetrics.registerSerde(VarianceSerde.TYPE_NAME, new VarianceSerde());
+
final QueryableIndex index =
- IndexBuilder.create()
+ IndexBuilder.create(CalciteTests.getJsonMapper().registerModules(new
DruidStatsModule().getJacksonModules()))
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
@@ -100,7 +104,8 @@ public class VarianceSqlAggregatorTest extends
BaseCalciteQueryTest
)
.withMetrics(
new CountAggregatorFactory("cnt"),
- new DoubleSumAggregatorFactory("m1", "m1")
+ new DoubleSumAggregatorFactory("m1", "m1"),
+ new VarianceAggregatorFactory("var1", "m1",
null, null)
)
.withRollup(false)
.build()
@@ -624,6 +629,55 @@ public class VarianceSqlAggregatorTest extends
BaseCalciteQueryTest
);
}
+ @Test
+ public void testVarianceAggAsInput()
+ {
+ final List<Object[]> expectedResults = ImmutableList.of(
+ new Object[]{
+ "3.5",
+ "2.9166666666666665",
+ "3.5",
+ "1.8708286933869707",
+ "1.707825127659933",
+ "1.8708286933869707"
+ }
+ );
+ testQuery(
+ "SELECT\n"
+ + "VARIANCE(var1),\n"
+ + "VAR_POP(var1),\n"
+ + "VAR_SAMP(var1),\n"
+ + "STDDEV(var1),\n"
+ + "STDDEV_POP(var1),\n"
+ + "STDDEV_SAMP(var1)\n"
+ + "FROM numfoo",
+ ImmutableList.of(
+ Druids.newTimeseriesQueryBuilder()
+ .dataSource(CalciteTests.DATASOURCE3)
+ .intervals(new
MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+ .granularity(Granularities.ALL)
+ .aggregators(
+ ImmutableList.of(
+ new VarianceAggregatorFactory("a0:agg", "var1",
"sample", "variance"),
+ new VarianceAggregatorFactory("a1:agg", "var1",
"population", "variance"),
+ new VarianceAggregatorFactory("a2:agg", "var1",
"sample", "variance"),
+ new VarianceAggregatorFactory("a3:agg", "var1",
"sample", "variance"),
+ new VarianceAggregatorFactory("a4:agg", "var1",
"population", "variance"),
+ new VarianceAggregatorFactory("a5:agg", "var1",
"sample", "variance")
+ )
+ )
+ .postAggregators(
+ new StandardDeviationPostAggregator("a3", "a3:agg",
"sample"),
+ new StandardDeviationPostAggregator("a4", "a4:agg",
"population"),
+ new StandardDeviationPostAggregator("a5", "a5:agg",
"sample")
+ )
+ .context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ expectedResults
+ );
+ }
+
@Override
public void assertResultsEquals(String sql, List<Object[]> expectedResults,
List<Object[]> results)
{
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java
b/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java
index 32abe56ee8..87519c7537 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java
@@ -23,11 +23,18 @@ import com.google.common.base.Preconditions;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeComparability;
import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.sql.SqlCallBinding;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperandCountRange;
+import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.AbstractSqlType;
+import org.apache.calcite.sql.type.SqlOperandCountRanges;
+import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.ordering.StringComparator;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.segment.column.ColumnHolder;
@@ -79,7 +86,9 @@ public class RowSignatures
{
Preconditions.checkNotNull(simpleExtraction, "simpleExtraction");
if (simpleExtraction.getExtractionFn() != null
- || rowSignature.getColumnType(simpleExtraction.getColumn()).map(type
-> type.is(ValueType.STRING)).orElse(false)) {
+ || rowSignature.getColumnType(simpleExtraction.getColumn())
+ .map(type -> type.is(ValueType.STRING))
+ .orElse(false)) {
return StringComparators.LEXICOGRAPHIC;
} else {
return StringComparators.NUMERIC;
@@ -164,7 +173,7 @@ public class RowSignatures
* Creates a {@link ComplexSqlType} using the supplied {@link
RelDataTypeFactory} to ensure that the
* {@link ComplexSqlType} is interned. This is important because Calcite
checks that the references are equal
* instead of the objects being equivalent.
- *
+ * <p>
* This method uses {@link
RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean) ensures that
if the
* type factory is a {@link
org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed
through
* {@link
org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which
interns the type.
@@ -179,15 +188,15 @@ public class RowSignatures
/**
* Calcite {@link RelDataType} for Druid complex columns, to preserve
complex type information.
- *
+ * <p>
* If using with other operations of a {@link RelDataTypeFactory}, consider
wrapping the creation of this type in
* {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean)
to ensure that if the type factory is a
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type
is passed through
* {@link
org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which
interns the type.
- *
+ * <p>
* If {@link SqlTypeName} is going to be {@link SqlTypeName#OTHER} and a
{@link RelDataTypeFactory} is available,
* consider using {@link #makeComplexType(RelDataTypeFactory, ColumnType,
boolean)}.
- *
+ * <p>
* This type does not work well with {@link
org.apache.calcite.sql.type.ReturnTypes#explicit(RelDataType)}, which
* will create new {@link RelDataType} using {@link SqlTypeName} during
return type inference, so implementors of
* {@link org.apache.druid.sql.calcite.expression.SqlOperatorConversion}
should implement the
@@ -235,4 +244,67 @@ public class RowSignatures
return columnType.asTypeString();
}
}
+
+ public static ComplexSqlSingleOperandTypeChecker
complexTypeChecker(ColumnType complexType)
+ {
+ return new ComplexSqlSingleOperandTypeChecker(
+ new ComplexSqlType(SqlTypeName.OTHER, complexType, true)
+ );
+ }
+
+ public static final class ComplexSqlSingleOperandTypeChecker implements
SqlSingleOperandTypeChecker
+ {
+ private final ComplexSqlType type;
+
+ public ComplexSqlSingleOperandTypeChecker(
+ ComplexSqlType type
+ )
+ {
+ this.type = type;
+ }
+
+ @Override
+ public boolean checkSingleOperandType(
+ SqlCallBinding callBinding,
+ SqlNode operand,
+ int iFormalOperand,
+ boolean throwOnFailure
+ )
+ {
+ return
type.equals(callBinding.getValidator().deriveType(callBinding.getScope(),
operand));
+ }
+
+ @Override
+ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean
throwOnFailure)
+ {
+ if (callBinding.getOperandCount() != 1) {
+ return false;
+ }
+ return checkSingleOperandType(callBinding, callBinding.operand(0), 0,
throwOnFailure);
+ }
+
+ @Override
+ public SqlOperandCountRange getOperandCountRange()
+ {
+ return SqlOperandCountRanges.of(1);
+ }
+
+ @Override
+ public String getAllowedSignatures(SqlOperator op, String opName)
+ {
+ return StringUtils.format("'%s'(%s)", opName, type);
+ }
+
+ @Override
+ public Consistency getConsistency()
+ {
+ return Consistency.NONE;
+ }
+
+ @Override
+ public boolean isOptional(int i)
+ {
+ return false;
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]