[ https://issues.apache.org/jira/browse/DRILL-6375?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16505453#comment-16505453 ]
ASF GitHub Bot commented on DRILL-6375: --------------------------------------- sohami closed pull request #1256: DRILL-6375 : Support for ANY_VALUE aggregate function URL: https://github.com/apache/drill/pull/1256 This is a PR merged from a forked repository. Because GitHub hides the original diff of a pull request from a fork once it is merged, the diff is reproduced below for the sake of provenance: diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFuncHolder.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFuncHolder.java index 8e7b645b0a..80f299ecdf 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFuncHolder.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFuncHolder.java @@ -130,7 +130,7 @@ public int getParamCount() { * @return workspace variables */ @Override - public JVar[] renderStart(ClassGenerator<?> g, HoldingContainer[] inputVariables){ + public JVar[] renderStart(ClassGenerator<?> g, HoldingContainer[] inputVariables, FieldReference fieldReference){ JVar[] workspaceJVars = new JVar[5]; workspaceJVars[0] = g.declareClassField("returnOI", g.getModel()._ref(ObjectInspector.class)); diff --git a/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd b/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd index 202f539cb1..3fb2601418 100644 --- a/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd +++ b/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd @@ -88,6 +88,52 @@ {inputType: "Interval", outputType: "NullableInterval", runningType: "Interval", major: "Date", initialValue: "0"}, {inputType: "NullableInterval", outputType: "NullableInterval", runningType: "Interval", major: "Date", initialValue: "0"} ] + }, + {className: "AnyValue", funcName: "any_value", types: [ + {inputType: "Bit", outputType: "NullableBit", runningType: "Bit", major: 
"Numeric"}, + {inputType: "Int", outputType: "NullableInt", runningType: "Int", major: "Numeric"}, + {inputType: "BigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"}, + {inputType: "NullableInt", outputType: "NullableInt", runningType: "Int", major: "Numeric"}, + {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "Float4", outputType: "NullableFloat4", runningType: "Float4", major: "Numeric"}, + {inputType: "Float8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, + {inputType: "NullableFloat4", outputType: "NullableFloat4", runningType: "Float4", major: "Numeric"}, + {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, + {inputType: "Date", outputType: "NullableDate", runningType: "Date", major: "Date", initialValue: "0"}, + {inputType: "NullableDate", outputType: "NullableDate", runningType: "Date", major: "Date", initialValue: "0"}, + {inputType: "TimeStamp", outputType: "NullableTimeStamp", runningType: "TimeStamp", major: "Date", initialValue: "0"}, + {inputType: "NullableTimeStamp", outputType: "NullableTimeStamp", runningType: "TimeStamp", major: "Date", initialValue: "0"}, + {inputType: "Time", outputType: "NullableTime", runningType: "Time", major: "Date", initialValue: "0"}, + {inputType: "NullableTime", outputType: "NullableTime", runningType: "Time", major: "Date", initialValue: "0"}, + {inputType: "IntervalDay", outputType: "NullableIntervalDay", runningType: "IntervalDay", major: "Date", initialValue: "0"}, + {inputType: "NullableIntervalDay", outputType: "NullableIntervalDay", runningType: "IntervalDay", major: "Date", initialValue: "0"}, + {inputType: "IntervalYear", outputType: "NullableIntervalYear", runningType: "IntervalYear", major: "Date", initialValue: "0"}, + {inputType: 
"NullableIntervalYear", outputType: "NullableIntervalYear", runningType: "IntervalYear", major: "Date", initialValue: "0"}, + {inputType: "Interval", outputType: "NullableInterval", runningType: "Interval", major: "Date", initialValue: "0"}, + {inputType: "NullableInterval", outputType: "NullableInterval", runningType: "Interval", major: "Date", initialValue: "0"}, + {inputType: "VarChar", outputType: "NullableVarChar", runningType: "VarChar", major: "VarBytes", initialValue: ""}, + {inputType: "NullableVarChar", outputType: "NullableVarChar", runningType: "VarChar", major: "VarBytes", initialValue: ""}, + {inputType: "VarBinary", outputType: "NullableVarBinary", runningType: "VarBinary", major: "VarBytes"}, + {inputType: "NullableVarBinary", outputType: "NullableVarBinary", runningType: "VarBinary", major: "VarBytes"} + {inputType: "List", outputType: "List", runningType: "List", major: "Complex"} + {inputType: "Map", outputType: "Map", runningType: "Map", major: "Complex"} + {inputType: "RepeatedBit", outputType: "RepeatedBit", runningType: "RepeatedBit", major: "Complex"}, + {inputType: "RepeatedInt", outputType: "RepeatedInt", runningType: "RepeatedInt", major: "Complex"}, + {inputType: "RepeatedBigInt", outputType: "RepeatedBigInt", runningType: "RepeatedBigInt", major: "Complex"}, + {inputType: "RepeatedFloat4", outputType: "RepeatedFloat4", runningType: "RepeatedFloat4", major: "Complex"}, + {inputType: "RepeatedFloat8", outputType: "RepeatedFloat8", runningType: "RepeatedFloat8", major: "Complex"}, + {inputType: "RepeatedDate", outputType: "RepeatedDate", runningType: "RepeatedDate", major: "Complex"}, + {inputType: "RepeatedTimeStamp", outputType: "RepeatedTimeStamp", runningType: "RepeatedTimeStamp", major: "Complex"}, + {inputType: "RepeatedTime", outputType: "RepeatedTime", runningType: "RepeatedTime", major: "Complex"}, + {inputType: "RepeatedIntervalDay", outputType: "RepeatedIntervalDay", runningType: "RepeatedIntervalDay", major: "Complex"}, + 
{inputType: "RepeatedIntervalYear", outputType: "RepeatedIntervalYear", runningType: "RepeatedIntervalYear", major: "Complex"}, + {inputType: "RepeatedInterval", outputType: "RepeatedInterval", runningType: "RepeatedInterval", major: "Complex"}, + {inputType: "RepeatedVarChar", outputType: "RepeatedVarChar", runningType: "RepeatedVarChar", major: "Complex"}, + {inputType: "RepeatedVarBinary", outputType: "RepeatedVarBinary", runningType: "RepeatedVarBinary", major: "Complex"}, + {inputType: "RepeatedList", outputType: "RepeatedList", runningType: "RepeatedList", major: "Complex"}, + {inputType: "RepeatedMap", outputType: "RepeatedMap", runningType: "RepeatedMap", major: "Complex"} + ] } ] } diff --git a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd index 7da2d071f1..003bbfa1b2 100644 --- a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd +++ b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd @@ -35,6 +35,12 @@ {inputType: "VarDecimal", outputType: "NullableVarDecimal"}, {inputType: "NullableVarDecimal", outputType: "NullableVarDecimal"} ] + }, + {className: "AnyValue", funcName: "any_value", types: [ + {inputType: "VarDecimal", outputType: "NullableVarDecimal"}, + {inputType: "NullableVarDecimal", outputType: "NullableVarDecimal"} + {inputType: "RepeatedVarDecimal", outputType: "RepeatedVarDecimal"} + ] } ] } diff --git a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java index ebf20e53e1..59d37157f7 100644 --- a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java +++ b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java @@ -61,11 +61,11 @@ public void setup() { value = new ${type.runningType}Holder(); nonNullCount = new BigIntHolder(); nonNullCount.value = 0; - <#if aggrtype.funcName == "sum"> + <#if aggrtype.funcName == "sum" || aggrtype.funcName == "any_value"> 
value.value = 0; <#elseif aggrtype.funcName == "min"> <#if type.runningType?starts_with("Bit")> - value.value = 1; + value.value = 1; <#elseif type.runningType?starts_with("Int")> value.value = Integer.MAX_VALUE; <#elseif type.runningType?starts_with("BigInt")> @@ -77,7 +77,7 @@ public void setup() { </#if> <#elseif aggrtype.funcName == "max"> <#if type.runningType?starts_with("Bit")> - value.value = 0; + value.value = 0; <#elseif type.runningType?starts_with("Int")> value.value = Integer.MIN_VALUE; <#elseif type.runningType?starts_with("BigInt")> @@ -110,19 +110,21 @@ public void add() { value.value = Float.isNaN(value.value) ? in.value : Math.min(value.value, in.value); } <#elseif type.inputType?contains("Float8")> - if(!Double.isNaN(in.value)) { - value.value = Double.isNaN(value.value) ? in.value : Math.min(value.value, in.value); - } - <#else> + if(!Double.isNaN(in.value)) { + value.value = Double.isNaN(value.value) ? in.value : Math.min(value.value, in.value); + } + <#else> value.value = Math.min(value.value, in.value); - </#if> + </#if> <#elseif aggrtype.funcName == "max"> value.value = Math.max(value.value, in.value); <#elseif aggrtype.funcName == "sum"> value.value += in.value; <#elseif aggrtype.funcName == "count"> value.value++; - <#else> + <#elseif aggrtype.funcName == "any_value"> + value.value = in.value; + <#else> // TODO: throw an error ? 
</#if> <#if type.inputType?starts_with("Nullable")> @@ -143,7 +145,7 @@ public void output() { @Override public void reset() { nonNullCount.value = 0; - <#if aggrtype.funcName == "sum" || aggrtype.funcName == "count"> + <#if aggrtype.funcName == "sum" || aggrtype.funcName == "count" || aggrtype.funcName == "any_value"> value.value = 0; <#elseif aggrtype.funcName == "min"> <#if type.runningType?starts_with("Int")> diff --git a/exec/java-exec/src/main/codegen/templates/ComplexAggrFunctions1.java b/exec/java-exec/src/main/codegen/templates/ComplexAggrFunctions1.java new file mode 100644 index 0000000000..6aa92e3c09 --- /dev/null +++ b/exec/java-exec/src/main/codegen/templates/ComplexAggrFunctions1.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +<@pp.dropOutputFile /> + + + +<#list aggrtypes1.aggrtypes as aggrtype> +<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/gaggr/${aggrtype.className}ComplexFunctions.java" /> + +<#include "/@includes/license.ftl" /> + +/* + * This class is generated using freemarker and the ${.template_name} template. 
+ */ + +<#-- A utility class that is used to generate java code for aggr functions that maintain a single --> +<#-- running counter to hold the result. This includes: ANY_VALUE. --> + +package org.apache.drill.exec.expr.fn.impl.gaggr; + +import org.apache.drill.exec.expr.DrillAggFunc; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; +import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope; +import org.apache.drill.exec.expr.annotations.Output; +import org.apache.drill.exec.expr.annotations.Param; +import org.apache.drill.exec.expr.annotations.Workspace; +import org.apache.drill.exec.expr.holders.*; +import org.apache.drill.exec.vector.complex.reader.FieldReader; +import org.apache.drill.exec.vector.complex.MapUtility; +import org.apache.drill.exec.vector.complex.writer.*; +import org.apache.drill.exec.vector.complex.writer.BaseWriter.*; + +@SuppressWarnings("unused") + +public class ${aggrtype.className}ComplexFunctions { +static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(${aggrtype.className}ComplexFunctions.class); + +<#list aggrtype.types as type> +<#if type.major == "Complex"> + +@FunctionTemplate(name = "${aggrtype.funcName}", scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE) +public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc{ + @Param ${type.inputType}Holder inHolder; + @Workspace BigIntHolder nonNullCount; + @Output ComplexWriter writer; + + public void setup() { + nonNullCount = new BigIntHolder(); + nonNullCount.value = 0; + } + + @Override + public void add() { + <#if type.inputType?starts_with("Nullable")> + sout: { + if (inHolder.isSet == 0) { + // processing nullable input and the value is null, so don't do anything... 
+ break sout; + } + </#if> + <#if aggrtype.funcName == "any_value"> + <#if type.runningType?starts_with("Map")> + if (nonNullCount.value == 0) { + org.apache.drill.exec.expr.fn.impl.MappifyUtility.createMap(inHolder.reader, writer, "any_value"); + } + <#elseif type.runningType?starts_with("RepeatedMap")> + if (nonNullCount.value == 0) { + org.apache.drill.exec.expr.fn.impl.MappifyUtility.createRepeatedMapOrList(inHolder.reader, writer, "any_value"); + } + <#elseif type.runningType?starts_with("List")> + if (nonNullCount.value == 0) { + org.apache.drill.exec.expr.fn.impl.MappifyUtility.createList(inHolder.reader, writer, "any_value"); + } + <#elseif type.runningType?starts_with("RepeatedList")> + if (nonNullCount.value == 0) { + org.apache.drill.exec.expr.fn.impl.MappifyUtility.createRepeatedMapOrList(inHolder.reader, writer, "any_value"); + } + <#elseif type.runningType?starts_with("Repeated")> + if (nonNullCount.value == 0) { + org.apache.drill.exec.expr.fn.impl.MappifyUtility.createList(inHolder.reader, writer, "any_value"); + } + </#if> + </#if> + nonNullCount.value = 1; + <#if type.inputType?starts_with("Nullable")> + } // end of sout block + </#if> + } + + @Override + public void output() { + //Do nothing since the complex writer takes care of everything! 
+ } + + @Override + public void reset() { + <#if aggrtype.funcName == "any_value"> + nonNullCount.value = 0; + </#if> + } +} +</#if> +</#list> +} +</#list> \ No newline at end of file diff --git a/exec/java-exec/src/main/codegen/templates/DateIntervalAggrFunctions1.java b/exec/java-exec/src/main/codegen/templates/DateIntervalAggrFunctions1.java index f526575b02..8080ea76b5 100644 --- a/exec/java-exec/src/main/codegen/templates/DateIntervalAggrFunctions1.java +++ b/exec/java-exec/src/main/codegen/templates/DateIntervalAggrFunctions1.java @@ -131,7 +131,16 @@ public void add() { </#if> <#elseif aggrtype.funcName == "count"> value.value++; - <#else> + <#elseif aggrtype.funcName == "any_value"> + <#if type.outputType?ends_with("Interval")> + value.days = in.days; + value.months = in.months; + value.milliseconds = in.milliseconds; + <#elseif type.outputType?ends_with("IntervalDay")> + value.days = in.days; + value.milliseconds = in.milliseconds; + </#if> + <#else> // TODO: throw an error ? </#if> <#if type.inputType?starts_with("Nullable")> diff --git a/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java index 6b23f92edc..7f4ca154cd 100644 --- a/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java +++ b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java @@ -39,6 +39,8 @@ import org.apache.drill.exec.expr.annotations.Output; import org.apache.drill.exec.expr.annotations.Param; import org.apache.drill.exec.expr.annotations.Workspace; +import org.apache.drill.exec.vector.complex.writer.*; +import org.apache.drill.exec.vector.complex.writer.BaseWriter.*; import javax.inject.Inject; import io.netty.buffer.DrillBuf; import org.apache.drill.exec.expr.holders.*; @@ -124,6 +126,101 @@ public void reset() { nonNullCount.value = 0; } } + <#elseif aggrtype.funcName.contains("any_value") && 
type.inputType?starts_with("Repeated")> + @FunctionTemplate(name = "${aggrtype.funcName}", + scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE, + returnType = FunctionTemplate.ReturnType.DECIMAL_AGGREGATE) + public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc { + @Param ${type.inputType}Holder in; + @Output ComplexWriter writer; + @Workspace BigIntHolder nonNullCount; + + public void setup() { + nonNullCount = new BigIntHolder(); + } + + @Override + public void add() { + if (nonNullCount.value == 0) { + org.apache.drill.exec.expr.fn.impl.MappifyUtility.createList(in.reader, writer, "any_value"); + } + nonNullCount.value = 1; + } + + @Override + public void output() { + } + + @Override + public void reset() { + nonNullCount.value = 0; + } + } + <#elseif aggrtype.funcName.contains("any_value")> + @FunctionTemplate(name = "${aggrtype.funcName}", + scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE, + returnType = FunctionTemplate.ReturnType.DECIMAL_AGGREGATE) + public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc { + @Param ${type.inputType}Holder in; + @Inject DrillBuf buffer; + @Workspace ObjectHolder value; + @Workspace IntHolder scale; + @Workspace IntHolder precision; + @Output ${type.outputType}Holder out; + @Workspace BigIntHolder nonNullCount; + + public void setup() { + value = new ObjectHolder(); + value.obj = java.math.BigDecimal.ZERO; + nonNullCount = new BigIntHolder(); + } + + @Override + public void add() { + <#if type.inputType?starts_with("Nullable")> + sout: { + if (in.isSet == 0) { + // processing nullable input and the value is null, so don't do anything... 
+ break sout; + } + </#if> + if (nonNullCount.value == 0) { + value.obj=org.apache.drill.exec.util.DecimalUtility + .getBigDecimalFromDrillBuf(in.buffer,in.start,in.end-in.start,in.scale); + scale.value = in.scale; + precision.value = in.precision; + } + nonNullCount.value = 1; + <#if type.inputType?starts_with("Nullable")> + } // end of sout block + </#if> + } + + @Override + public void output() { + if (nonNullCount.value > 0) { + out.isSet = 1; + byte[] bytes = ((java.math.BigDecimal)value.obj).unscaledValue().toByteArray(); + int len = bytes.length; + out.start = 0; + out.buffer = buffer.reallocIfNeeded(len); + out.buffer.setBytes(0, bytes); + out.end = len; + out.scale = scale.value; + out.precision = precision.value; + } else { + out.isSet = 0; + } + } + + @Override + public void reset() { + scale.value = 0; + precision.value = 0; + value.obj = null; + nonNullCount.value = 0; + } + } <#elseif aggrtype.funcName == "max" || aggrtype.funcName == "min"> @FunctionTemplate(name = "${aggrtype.funcName}", diff --git a/exec/java-exec/src/main/codegen/templates/VarCharAggrFunctions1.java b/exec/java-exec/src/main/codegen/templates/VarCharAggrFunctions1.java index a5afce98cf..de5d705e30 100644 --- a/exec/java-exec/src/main/codegen/templates/VarCharAggrFunctions1.java +++ b/exec/java-exec/src/main/codegen/templates/VarCharAggrFunctions1.java @@ -90,6 +90,16 @@ public void add() { break sout; } </#if> + <#if aggrtype.className == "AnyValue"> + if (nonNullCount.value == 0) { + nonNullCount.value = 1; + int inputLength = in.end - in.start; + org.apache.drill.exec.expr.fn.impl.DrillByteArray tmp = (org.apache.drill.exec.expr.fn.impl.DrillByteArray) value.obj; + byte[] tempArray = new byte[inputLength]; + in.buffer.getBytes(in.start, tempArray, 0, inputLength); + tmp.setBytes(tempArray); + } + <#else> nonNullCount.value = 1; org.apache.drill.exec.expr.fn.impl.DrillByteArray tmp = (org.apache.drill.exec.expr.fn.impl.DrillByteArray) value.obj; int cmp = 0; @@ -121,6 +131,7 @@ 
public void add() { tmp.setBytes(tempArray); } } + </#if> <#if type.inputType?starts_with("Nullable")> } // end of sout block </#if> diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/sig/ConstantExpressionIdentifier.java b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/sig/ConstantExpressionIdentifier.java index d7646633d6..0175d51a9c 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/sig/ConstantExpressionIdentifier.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/sig/ConstantExpressionIdentifier.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Set; +import org.apache.drill.common.expression.AnyValueExpression; import org.apache.drill.common.expression.BooleanOperator; import org.apache.drill.common.expression.CastExpression; import org.apache.drill.common.expression.ConvertExpression; @@ -234,6 +235,12 @@ public Boolean visitConvertExpression(ConvertExpression e, return e.getInput().accept(this, value); } + @Override + public Boolean visitAnyValueExpression(AnyValueExpression e, + IdentityHashMap<LogicalExpression, Object> value) throws RuntimeException { + return e.getInput().accept(this, value); + } + @Override public Boolean visitParameter(ValueExpressions.ParameterExpression e, IdentityHashMap<LogicalExpression, Object> value) throws RuntimeException { return false; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java index 64cfe66953..4486972ee1 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java @@ -24,6 +24,7 @@ import java.util.Set; import java.util.Stack; +import org.apache.drill.common.expression.AnyValueExpression; import org.apache.drill.common.expression.BooleanOperator; import 
org.apache.drill.common.expression.CastExpression; import org.apache.drill.common.expression.ConvertExpression; @@ -181,7 +182,7 @@ public HoldingContainer visitFunctionHolderExpression(FunctionHolderExpression h AbstractFuncHolder holder = (AbstractFuncHolder) holderExpr.getHolder(); - JVar[] workspaceVars = holder.renderStart(generator, null); + JVar[] workspaceVars = holder.renderStart(generator, null, holderExpr.getFieldReference()); if (holder.isNested()) { generator.getMappingSet().enterChild(); @@ -456,8 +457,7 @@ private HoldingContainer visitValueVectorReadExpression(ValueVectorReadExpressio generator.getEvalBlock().add(eval); } else { - JExpression vector = e.isSuperReader() ? vv1.component(componentVariable) : vv1; - JExpression expr = vector.invoke("getReader"); + JExpression expr = vv1.invoke("getReader"); PathSegment seg = e.getReadPath(); JVar isNull = null; @@ -713,6 +713,17 @@ public HoldingContainer visitConvertExpression(ConvertExpression e, ClassGenerat return fc.accept(this, value); } + @Override + public HoldingContainer visitAnyValueExpression(AnyValueExpression e, ClassGenerator<?> value) + throws RuntimeException { + + List<LogicalExpression> newArgs = Lists.newArrayList(); + newArgs.add(e.getInput()); // input_expr + + FunctionCall fc = new FunctionCall(AnyValueExpression.ANY_VALUE, newArgs, e.getPosition()); + return fc.accept(this, value); + } + private HoldingContainer visitBooleanAnd(BooleanOperator op, ClassGenerator<?> generator) { diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/AbstractFuncHolder.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/AbstractFuncHolder.java index 4902260f2c..7dd58ace24 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/AbstractFuncHolder.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/AbstractFuncHolder.java @@ -32,7 +32,7 @@ public abstract class AbstractFuncHolder implements FuncHolder { - public abstract JVar[] 
renderStart(ClassGenerator<?> g, HoldingContainer[] inputVariables); + public abstract JVar[] renderStart(ClassGenerator<?> g, HoldingContainer[] inputVariables, FieldReference fieldReference); public void renderMiddle(ClassGenerator<?> g, HoldingContainer[] inputVariables, JVar[] workspaceJVars) { // default implementation is add no code diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillAggFuncHolder.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillAggFuncHolder.java index e1cd96fefb..1a5df670f6 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillAggFuncHolder.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillAggFuncHolder.java @@ -21,6 +21,7 @@ import org.apache.drill.common.exceptions.DrillRuntimeException; import org.apache.drill.common.expression.FieldReference; +import org.apache.drill.common.types.TypeProtos; import org.apache.drill.common.types.TypeProtos.DataMode; import org.apache.drill.common.types.TypeProtos.MajorType; import org.apache.drill.common.types.Types; @@ -44,19 +45,19 @@ class DrillAggFuncHolder extends DrillFuncHolder { static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillAggFuncHolder.class); - private String setup() { + protected String setup() { return meth("setup"); } - private String reset() { + protected String reset() { return meth("reset", false); } - private String add() { + protected String add() { return meth("add"); } - private String output() { + protected String output() { return meth("output"); } - private String cleanup() { + protected String cleanup() { return meth("cleanup", false); } @@ -78,7 +79,7 @@ public boolean isAggregating() { } @Override - public JVar[] renderStart(ClassGenerator<?> g, HoldingContainer[] inputVariables) { + public JVar[] renderStart(ClassGenerator<?> g, HoldingContainer[] inputVariables, FieldReference fieldReference) { if (!g.getMappingSet().isHashAggMapping()) { //Declare 
workspace vars for non-hash-aggregation. JVar[] workspaceJVars = declareWorkspaceVariables(g); generateBody(g, BlockType.SETUP, setup(), null, workspaceJVars, true); @@ -128,12 +129,20 @@ public void renderMiddle(ClassGenerator<?> g, HoldingContainer[] inputVariables, @Override public HoldingContainer renderEnd(ClassGenerator<?> classGenerator, HoldingContainer[] inputVariables, JVar[] workspaceJVars, FieldReference fieldReference) { - HoldingContainer out = classGenerator.declare(getReturnType(), false); + HoldingContainer out = null; + JVar internalOutput = null; + if (getReturnType().getMinorType() != TypeProtos.MinorType.LATE) { + out = classGenerator.declare(getReturnType(), false); + } JBlock sub = new JBlock(); + if (getReturnType().getMinorType() != TypeProtos.MinorType.LATE) { + internalOutput = sub.decl(JMod.FINAL, classGenerator.getHolderType(getReturnType()), getReturnValue().getName(), JExpr._new(classGenerator.getHolderType(getReturnType()))); + } classGenerator.getEvalBlock().add(sub); - JVar internalOutput = sub.decl(JMod.FINAL, classGenerator.getHolderType(getReturnType()), getReturnValue().getName(), JExpr._new(classGenerator.getHolderType(getReturnType()))); addProtectedBlock(classGenerator, sub, output(), null, workspaceJVars, false); - sub.assign(out.getHolder(), internalOutput); + if (getReturnType().getMinorType() != TypeProtos.MinorType.LATE) { + sub.assign(out.getHolder(), internalOutput); + } //hash aggregate uses workspace vectors. Initialization is done in "setup" and does not require "reset" block. 
if (!classGenerator.getMappingSet().isHashAggMapping()) { generateBody(classGenerator, BlockType.RESET, reset(), null, workspaceJVars, false); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillComplexWriterAggFuncHolder.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillComplexWriterAggFuncHolder.java new file mode 100644 index 0000000000..44766bdbb3 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillComplexWriterAggFuncHolder.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.drill.exec.expr.fn; + +import org.apache.drill.common.expression.FieldReference; +import org.apache.drill.common.types.TypeProtos; +import org.apache.drill.exec.expr.ClassGenerator; +import org.apache.drill.exec.expr.ClassGenerator.HoldingContainer; +import org.apache.drill.exec.physical.impl.aggregate.StreamingAggBatch; +import org.apache.drill.exec.physical.impl.aggregate.StreamingAggTemplate; +import org.apache.drill.exec.record.VectorAccessibleComplexWriter; +import org.apache.drill.exec.vector.complex.writer.BaseWriter.ComplexWriter; + +import com.sun.codemodel.JBlock; +import com.sun.codemodel.JClass; +import com.sun.codemodel.JExpr; +import com.sun.codemodel.JExpression; +import com.sun.codemodel.JInvocation; +import com.sun.codemodel.JVar; +import com.sun.codemodel.JMod; + +public class DrillComplexWriterAggFuncHolder extends DrillAggFuncHolder { + + // Complex writer to write out complex data-types e.g. repeated maps/lists + private JVar complexWriter; + // The index at which to write - important when group-by is present. Implicit assumption that the output indexes + // will be sequential starting from 0. i.e. the first group would be written at index 0, second group at index 1 + // and so on. + private JVar writerIdx; + private JVar lastWriterIdx; + public DrillComplexWriterAggFuncHolder(FunctionAttributes functionAttributes, FunctionInitializer initializer) { + super(functionAttributes, initializer); + } + + @Override + public boolean isComplexWriterFuncHolder() { + return true; + } + + @Override + public JVar[] renderStart(ClassGenerator<?> classGenerator, HoldingContainer[] inputVariables, FieldReference fieldReference) { + if (!classGenerator.getMappingSet().isHashAggMapping()) { //Declare workspace vars for non-hash-aggregation. 
+ JInvocation container = classGenerator.getMappingSet().getOutgoing().invoke("getOutgoingContainer"); + + complexWriter = classGenerator.declareClassField("complexWriter", classGenerator.getModel()._ref(ComplexWriter.class)); + writerIdx = classGenerator.declareClassField("writerIdx", classGenerator.getModel()._ref(int.class)); + lastWriterIdx = classGenerator.declareClassField("lastWriterIdx", classGenerator.getModel()._ref(int.class)); + //Default name is "col", if not passed in a reference name for the output vector. + String refName = fieldReference == null ? "col" : fieldReference.getRootSegment().getPath(); + JClass cwClass = classGenerator.getModel().ref(VectorAccessibleComplexWriter.class); + classGenerator.getSetupBlock().assign(complexWriter, cwClass.staticInvoke("getWriter").arg(refName).arg(container)); + classGenerator.getSetupBlock().assign(writerIdx, JExpr.lit(0)); + classGenerator.getSetupBlock().assign(lastWriterIdx, JExpr.lit(-1)); + + JVar[] workspaceJVars = declareWorkspaceVariables(classGenerator); + generateBody(classGenerator, ClassGenerator.BlockType.SETUP, setup(), null, workspaceJVars, true); + return workspaceJVars; + } else { + return super.renderStart(classGenerator, inputVariables, fieldReference); + } + } + + @Override + public void renderMiddle(ClassGenerator<?> classGenerator, HoldingContainer[] inputVariables, JVar[] workspaceJVars) { + + classGenerator.getEvalBlock().directStatement(String.format("//---- start of eval portion of %s function. 
----//", + getRegisteredNames()[0])); + + JBlock sub = new JBlock(true, true); + JBlock topSub = sub; + JClass aggBatchClass = null; + + if (classGenerator.getCodeGenerator().getDefinition() == StreamingAggTemplate.TEMPLATE_DEFINITION) { + aggBatchClass = classGenerator.getModel().ref(StreamingAggBatch.class); + } + assert aggBatchClass != null : "ComplexWriterAggFuncHolder should only be used with an Aggregate Operator"; + + JExpression aggBatch = JExpr.cast(aggBatchClass, classGenerator.getMappingSet().getOutgoing()); + + classGenerator.getSetupBlock().add(aggBatch.invoke("addComplexWriter").arg(complexWriter)); + // Only set the writer if there is a position change. Calling setPosition may cause underlying writers to allocate + // new vectors, thereby, losing the previously stored values + JBlock condAssignCW = classGenerator.getEvalBlock()._if(lastWriterIdx.ne(writerIdx))._then(); + condAssignCW.add(complexWriter.invoke("setPosition").arg(writerIdx)); + condAssignCW.assign(lastWriterIdx, writerIdx); + sub.decl(classGenerator.getModel()._ref(ComplexWriter.class), getReturnValue().getName(), complexWriter); + + // add the subblock after the out declaration. + classGenerator.getEvalBlock().add(topSub); + + addProtectedBlock(classGenerator, sub, add(), inputVariables, workspaceJVars, false); + classGenerator.getEvalBlock().directStatement(String.format("//---- end of eval portion of %s function. 
----//", + getRegisteredNames()[0])); + } + + @Override + public HoldingContainer renderEnd(ClassGenerator<?> classGenerator, HoldingContainer[] inputVariables, + JVar[] workspaceJVars, FieldReference fieldReference) { + HoldingContainer out = null; + JVar internalOutput = null; + if (getReturnType().getMinorType() != TypeProtos.MinorType.LATE) { + out = classGenerator.declare(getReturnType(), false); + } + JBlock sub = new JBlock(); + if (getReturnType().getMinorType() != TypeProtos.MinorType.LATE) { + internalOutput = sub.decl(JMod.FINAL, classGenerator.getHolderType(getReturnType()), getReturnValue().getName(), + JExpr._new(classGenerator.getHolderType(getReturnType()))); + } + classGenerator.getEvalBlock().add(sub); + if (getReturnType().getMinorType() == TypeProtos.MinorType.LATE) { + sub.assignPlus(writerIdx, JExpr.lit(1)); + } + addProtectedBlock(classGenerator, sub, output(), null, workspaceJVars, false); + if (getReturnType().getMinorType() != TypeProtos.MinorType.LATE) { + sub.assign(out.getHolder(), internalOutput); + } + //hash aggregate uses workspace vectors. Initialization is done in "setup" and does not require "reset" block. 
+ if (!classGenerator.getMappingSet().isHashAggMapping()) { + generateBody(classGenerator, ClassGenerator.BlockType.RESET, reset(), null, workspaceJVars, false); + } + generateBody(classGenerator, ClassGenerator.BlockType.CLEANUP, cleanup(), null, workspaceJVars, false); + + return out; + } +} + diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFuncHolder.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFuncHolder.java index 9df5305125..240ff27d7a 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFuncHolder.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFuncHolder.java @@ -23,8 +23,10 @@ import org.apache.drill.common.exceptions.DrillRuntimeException; import org.apache.drill.common.exceptions.UserException; import org.apache.drill.common.expression.ExpressionPosition; +import org.apache.drill.common.expression.FieldReference; import org.apache.drill.common.expression.FunctionHolderExpression; import org.apache.drill.common.expression.LogicalExpression; +import org.apache.drill.common.types.TypeProtos; import org.apache.drill.common.types.TypeProtos.MajorType; import org.apache.drill.common.types.TypeProtos.MinorType; import org.apache.drill.common.types.Types; @@ -36,6 +38,9 @@ import org.apache.drill.exec.expr.DrillFuncHolderExpr; import org.apache.drill.exec.expr.TypeHelper; import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling; +import org.apache.drill.exec.expr.holders.ListHolder; +import org.apache.drill.exec.expr.holders.MapHolder; +import org.apache.drill.exec.expr.holders.RepeatedMapHolder; import org.apache.drill.exec.ops.UdfUtilities; import org.apache.drill.exec.vector.complex.reader.FieldReader; @@ -80,7 +85,7 @@ protected String meth(String methodName, boolean required) { } @Override - public JVar[] renderStart(ClassGenerator<?> g, HoldingContainer[] inputVariables) { + public JVar[] renderStart(ClassGenerator<?> g, 
HoldingContainer[] inputVariables, FieldReference fieldReference) { return declareWorkspaceVariables(g); } @@ -186,12 +191,35 @@ protected void addProtectedBlock(ClassGenerator<?> g, JBlock sub, String body, H ValueReference parameter = attributes.getParameters()[i]; HoldingContainer inputVariable = inputVariables[i]; - if (parameter.isFieldReader() && ! inputVariable.isReader() && ! Types.isComplex(inputVariable.getMajorType()) && inputVariable.getMinorType() != MinorType.UNION) { + if (parameter.isFieldReader() && ! inputVariable.isReader() + && ! Types.isComplex(inputVariable.getMajorType()) && inputVariable.getMinorType() != MinorType.UNION) { JType singularReaderClass = g.getModel()._ref(TypeHelper.getHolderReaderImpl(inputVariable.getMajorType().getMinorType(), inputVariable.getMajorType().getMode())); JType fieldReadClass = g.getModel()._ref(FieldReader.class); sub.decl(fieldReadClass, parameter.getName(), JExpr._new(singularReaderClass).arg(inputVariable.getHolder())); - } else { + } else if (!parameter.isFieldReader() && inputVariable.isReader() && Types.isComplex(parameter.getType())) { + // For complex data-types (repeated maps/lists) the input to the aggregate will be a FieldReader. However, aggregate + // functions like ANY_VALUE, will assume the input to be a RepeatedMapHolder etc. Generate boilerplate code, to map + // from FieldReader to respective Holder. 
+ if (parameter.getType().getMinorType() == MinorType.MAP) { + JType holderClass; + if (parameter.getType().getMode() == TypeProtos.DataMode.REPEATED) { + holderClass = g.getModel()._ref(RepeatedMapHolder.class); + JVar holderVar = sub.decl(holderClass, parameter.getName(), JExpr._new(holderClass)); + sub.assign(holderVar.ref("reader"), inputVariable.getHolder()); + } else { + holderClass = g.getModel()._ref(MapHolder.class); + JVar holderVar = sub.decl(holderClass, parameter.getName(), JExpr._new(holderClass)); + sub.assign(holderVar.ref("reader"), inputVariable.getHolder()); + } + } else if (parameter.getType().getMinorType() == MinorType.LIST) { + //TODO: Add support for REPEATED LISTs + JType holderClass = g.getModel()._ref(ListHolder.class); + JVar holderVar = sub.decl(holderClass, parameter.getName(), JExpr._new(holderClass)); + sub.assign(holderVar.ref("reader"), inputVariable.getHolder()); + } + } + else { sub.decl(inputVariable.getHolder().type(), parameter.getName(), inputVariable.getHolder()); } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java index ca5605a582..b5a2f0700d 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java @@ -181,7 +181,9 @@ switch (template.scope()) { case POINT_AGGREGATE: - return new DrillAggFuncHolder(functionAttributes, initializer); + return outputField.isComplexWriter() ? + new DrillComplexWriterAggFuncHolder(functionAttributes, initializer) : + new DrillAggFuncHolder(functionAttributes, initializer); case SIMPLE: return outputField.isComplexWriter() ? 
new DrillComplexWriterFuncHolder(functionAttributes, initializer) : diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/Mappify.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/Mappify.java index 703d62e02e..3db9f5ac39 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/Mappify.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/Mappify.java @@ -60,7 +60,7 @@ public void setup() { } public void eval() { - buffer = org.apache.drill.exec.expr.fn.impl.MappifyUtility.mappify(reader, writer, buffer); + buffer = org.apache.drill.exec.expr.fn.impl.MappifyUtility.mappify(reader, writer, buffer, "Mappify/kvgen"); } } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/MappifyUtility.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/MappifyUtility.java index 3745fe2b34..b3fca2bb64 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/MappifyUtility.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/MappifyUtility.java @@ -37,7 +37,7 @@ public static final String fieldKey = "key"; public static final String fieldValue = "value"; - public static DrillBuf mappify(FieldReader reader, BaseWriter.ComplexWriter writer, DrillBuf buffer) { + public static DrillBuf mappify(FieldReader reader, BaseWriter.ComplexWriter writer, DrillBuf buffer, String caller) { // Currently we expect single map as input if (DataMode.REPEATED == reader.getType().getMode() || !(reader.getType().getMinorType() == TypeProtos.MinorType.MAP)) { throw new DrillRuntimeException("kvgen function only supports Simple maps as input"); @@ -72,7 +72,7 @@ public static DrillBuf mappify(FieldReader reader, BaseWriter.ComplexWriter writ mapWriter.varChar(fieldKey).write(vh); // Write the value to the map - MapUtility.writeToMapFromReader(fieldReader, mapWriter); + MapUtility.writeToMapFromReader(fieldReader, mapWriter, caller); 
mapWriter.end(); } @@ -80,5 +80,35 @@ public static DrillBuf mappify(FieldReader reader, BaseWriter.ComplexWriter writ return buffer; } + + public static void createRepeatedMapOrList(FieldReader reader, BaseWriter.ComplexWriter writer, String caller) { + if (DataMode.REPEATED != reader.getType().getMode()) { + throw new DrillRuntimeException("Do not invoke createRepeatedMapOrList() unless MINOR mode is REPEATED"); + } + BaseWriter.ListWriter listWriter = writer.rootAsList(); + MapUtility.writeToListFromReader(reader, listWriter, caller); + } + + public static void createMap(FieldReader reader, BaseWriter.ComplexWriter writer, String caller) { + if (DataMode.REPEATED == reader.getType().getMode()) { + throw new DrillRuntimeException("Do not invoke createMap() with REPEATED MINOR mode"); + } + if (reader.getType().getMinorType() == TypeProtos.MinorType.MAP) { + BaseWriter.MapWriter mapWriter = writer.rootAsMap(); + // Iterate over the fields in the map + Iterator<String> fieldIterator = reader.iterator(); + while (fieldIterator.hasNext()) { + String field = fieldIterator.next(); + FieldReader fieldReader = reader.reader(field); + // Write the value to the map + MapUtility.writeToMapFromReader(fieldReader, mapWriter, field, caller); + } + } + } + + public static void createList(FieldReader reader, BaseWriter.ComplexWriter writer, String caller) { + BaseWriter.ListWriter listWriter = writer.rootAsList(); + MapUtility.writeToListFromReader(reader, listWriter, caller); + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/StreamingAggBatch.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/StreamingAggBatch.java index 34ab97e2c0..caeed50df3 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/StreamingAggBatch.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/StreamingAggBatch.java @@ -18,7 +18,9 @@ package 
org.apache.drill.exec.physical.impl.aggregate; import java.io.IOException; +import java.util.List; +import com.google.common.collect.Lists; import org.apache.drill.common.exceptions.DrillRuntimeException; import org.apache.drill.common.exceptions.UserException; import org.apache.drill.common.expression.ErrorCollector; @@ -35,6 +37,7 @@ import org.apache.drill.exec.expr.ClassGenerator; import org.apache.drill.exec.expr.ClassGenerator.HoldingContainer; import org.apache.drill.exec.expr.CodeGenerator; +import org.apache.drill.exec.expr.DrillFuncHolderExpr; import org.apache.drill.exec.expr.ExpressionTreeMaterializer; import org.apache.drill.exec.expr.HoldingContainerExpression; import org.apache.drill.exec.expr.TypeHelper; @@ -50,21 +53,26 @@ import org.apache.drill.exec.record.MaterializedField; import org.apache.drill.exec.record.RecordBatch; import org.apache.drill.exec.record.TypedFieldId; +import org.apache.drill.exec.record.VectorContainer; import org.apache.drill.exec.record.VectorWrapper; import org.apache.drill.exec.record.selection.SelectionVector2; import org.apache.drill.exec.record.selection.SelectionVector4; import org.apache.drill.exec.vector.AllocationHelper; import org.apache.drill.exec.vector.FixedWidthVector; +import org.apache.drill.exec.vector.UntypedNullHolder; +import org.apache.drill.exec.vector.UntypedNullVector; import org.apache.drill.exec.vector.ValueVector; import com.sun.codemodel.JExpr; import com.sun.codemodel.JVar; +import org.apache.drill.exec.vector.complex.writer.BaseWriter; public class StreamingAggBatch extends AbstractRecordBatch<StreamingAggregate> { static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(StreamingAggBatch.class); private StreamingAggregator aggregator; private final RecordBatch incoming; + private List<BaseWriter.ComplexWriter> complexWriters; private boolean done = false; private boolean first = true; private int recordCount = 0; @@ -106,6 +114,11 @@ public int getRecordCount() { return 
recordCount; } + @Override + public VectorContainer getOutgoingContainer() { + return this.container; + } + @Override public void buildSchema() throws SchemaChangeException { IterOutcome outcome = next(incoming); @@ -131,6 +144,10 @@ public void buildSchema() throws SchemaChangeException { for (final VectorWrapper<?> w : container) { w.getValueVector().allocateNew(); } + + if (complexWriters != null) { + container.buildSchema(SelectionVectorMode.NONE); + } } @Override @@ -177,7 +194,6 @@ public IterOutcome innerNext() { throw new IllegalStateException(String.format("unknown outcome %s", outcome)); } } - AggOutcome out = aggregator.doWork(); recordCount = aggregator.getOutputCount(); logger.debug("Aggregator response {}, records {}", out, aggregator.getOutputCount()); @@ -191,6 +207,11 @@ public IterOutcome innerNext() { // fall through case RETURN_OUTCOME: IterOutcome outcome = aggregator.getOutcome(); + // In case of complex writer expression, vectors would be added to batch run-time. + // We have to re-build the schema. + if (complexWriters != null) { + container.buildSchema(SelectionVectorMode.NONE); + } if (outcome == IterOutcome.NONE && first) { first = false; done = true; @@ -213,6 +234,14 @@ public IterOutcome innerNext() { } } + private void allocateComplexWriters() { + // Allocate the complex writers before processing the incoming batch + if (complexWriters != null) { + for (final BaseWriter.ComplexWriter writer : complexWriters) { + writer.allocate(); + } + } + } /** * Method is invoked when we have a straight aggregate (no group by expression) and our input is empty. 
@@ -272,9 +301,15 @@ private boolean createAggregator() { } } + public void addComplexWriter(final BaseWriter.ComplexWriter writer) { + complexWriters.add(writer); + } + private StreamingAggregator createAggregatorInternal() throws SchemaChangeException, ClassTransformationException, IOException{ ClassGenerator<StreamingAggregator> cg = CodeGenerator.getRoot(StreamingAggTemplate.TEMPLATE_DEFINITION, context.getOptions()); cg.getCodeGenerator().plainJavaCapable(true); + // Uncomment out this line to debug the generated code. + //cg.getCodeGenerator().saveCodeForDebugging(true); container.clear(); LogicalExpression[] keyExprs = new LogicalExpression[popConfig.getKeys().size()]; @@ -307,12 +342,29 @@ private StreamingAggregator createAggregatorInternal() throws SchemaChangeExcept continue; } - final MaterializedField outputField = MaterializedField.create(ne.getRef().getLastSegment().getNameSegment().getPath(), - expr.getMajorType()); - @SuppressWarnings("resource") - ValueVector vector = TypeHelper.getNewVector(outputField, oContext.getAllocator()); - TypedFieldId id = container.add(vector); - valueExprs[i] = new ValueVectorWriteExpression(id, expr, true); + /* Populate the complex writers for complex exprs */ + if (expr instanceof DrillFuncHolderExpr && + ((DrillFuncHolderExpr) expr).getHolder().isComplexWriterFuncHolder()) { + // Need to process ComplexWriter function evaluation. + // Lazy initialization of the list of complex writers, if not done yet. + if (complexWriters == null) { + complexWriters = Lists.newArrayList(); + } else { + complexWriters.clear(); + } + // The reference name will be passed to ComplexWriter, used as the name of the output vector from the writer. 
+ ((DrillFuncHolderExpr) expr).getFieldReference(ne.getRef()); + MaterializedField field = MaterializedField.create(ne.getRef().getAsNamePart().getName(), UntypedNullHolder.TYPE); + container.add(new UntypedNullVector(field, container.getAllocator())); + valueExprs[i] = expr; + } else { + final MaterializedField outputField = MaterializedField.create(ne.getRef().getLastSegment().getNameSegment().getPath(), + expr.getMajorType()); + @SuppressWarnings("resource") + ValueVector vector = TypeHelper.getNewVector(outputField, oContext.getAllocator()); + TypedFieldId id = container.add(vector); + valueExprs[i] = new ValueVectorWriteExpression(id, expr, true); + } } if (collector.hasErrors()) { @@ -331,6 +383,7 @@ private StreamingAggregator createAggregatorInternal() throws SchemaChangeExcept container.buildSchema(SelectionVectorMode.NONE); StreamingAggregator agg = context.getImplementationClass(cg); agg.setup(oContext, incoming, this); + allocateComplexWriters(); return agg; } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java index 02dd4de257..19499d67bb 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java @@ -18,6 +18,7 @@ package org.apache.drill.exec.planner.physical; import com.google.common.collect.Lists; +import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.util.ImmutableBitSet; import org.apache.drill.exec.planner.logical.DrillAggregateRel; import org.apache.drill.exec.planner.logical.RelOptHelper; @@ -58,7 +59,8 @@ public void onMatch(RelOptRuleCall call) { final DrillAggregateRel aggregate = call.rel(0); final RelNode input = call.rel(1); - if (aggregate.containsDistinctCall() || aggregate.getGroupCount() == 0) { + if (aggregate.containsDistinctCall() || 
aggregate.getGroupCount() == 0 + || requiresStreamingAgg(aggregate)) { // currently, don't use HashAggregate if any of the logical aggrs contains DISTINCT or // if there are no grouping keys return; @@ -101,6 +103,16 @@ public void onMatch(RelOptRuleCall call) { } } + private boolean requiresStreamingAgg(DrillAggregateRel aggregate) { + //If contains ANY_VALUE aggregate, using HashAgg would not work + for (AggregateCall agg : aggregate.getAggCallList()) { + if (agg.getAggregation().getName().equalsIgnoreCase("any_value")) { + return true; + } + } + return false; + } + private class TwoPhaseSubset extends SubsetTransformer<DrillAggregateRel, InvalidRelException> { final RelTrait distOnAllKeys; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/vector/complex/MapUtility.java b/exec/java-exec/src/main/java/org/apache/drill/exec/vector/complex/MapUtility.java index 543a6db98c..f4d29e9af9 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/vector/complex/MapUtility.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/vector/complex/MapUtility.java @@ -28,14 +28,14 @@ import org.apache.drill.exec.vector.complex.writer.BaseWriter; public class MapUtility { - private final static String TYPE_MISMATCH_ERROR = "Mappify/kvgen does not support heterogeneous value types. All values in the input map must be of the same type. The field [%s] has a differing type [%s]."; + private final static String TYPE_MISMATCH_ERROR = " does not support heterogeneous value types. All values in the input map must be of the same type. The field [%s] has a differing type [%s]."; /* * Function to read a value from the field reader, detect the type, construct the appropriate value holder * and use the value holder to write to the Map. 
*/ // TODO : This should be templatized and generated using freemarker - public static void writeToMapFromReader(FieldReader fieldReader, BaseWriter.MapWriter mapWriter) { + public static void writeToMapFromReader(FieldReader fieldReader, BaseWriter.MapWriter mapWriter, String caller) { try { MajorType valueMajorType = fieldReader.getType(); MinorType valueMinorType = valueMajorType.getMinorType(); @@ -228,11 +228,311 @@ public static void writeToMapFromReader(FieldReader fieldReader, BaseWriter.MapW fieldReader.copyAsValue(mapWriter.list(MappifyUtility.fieldValue).list()); break; default: - throw new DrillRuntimeException(String.format("kvgen does not support input of type: %s", valueMinorType)); + throw new DrillRuntimeException(String.format(caller + + " does not support input of type: %s", valueMinorType)); } } catch (ClassCastException e) { final MaterializedField field = fieldReader.getField(); - throw new DrillRuntimeException(String.format(TYPE_MISMATCH_ERROR, field.getName(), field.getType())); + throw new DrillRuntimeException(String.format(caller + TYPE_MISMATCH_ERROR, field.getName(), field.getType())); + } + } + + public static void writeToMapFromReader(FieldReader fieldReader, BaseWriter.MapWriter mapWriter, + String fieldName, String caller) { + try { + MajorType valueMajorType = fieldReader.getType(); + MinorType valueMinorType = valueMajorType.getMinorType(); + boolean repeated = false; + + if (valueMajorType.getMode() == TypeProtos.DataMode.REPEATED) { + repeated = true; + } + + switch (valueMinorType) { + case TINYINT: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).tinyInt()); + } else { + fieldReader.copyAsValue(mapWriter.tinyInt(fieldName)); + } + break; + case SMALLINT: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).smallInt()); + } else { + fieldReader.copyAsValue(mapWriter.smallInt(fieldName)); + } + break; + case BIGINT: + if (repeated) { + 
fieldReader.copyAsValue(mapWriter.list(fieldName).bigInt()); + } else { + fieldReader.copyAsValue(mapWriter.bigInt(fieldName)); + } + break; + case INT: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).integer()); + } else { + fieldReader.copyAsValue(mapWriter.integer(fieldName)); + } + break; + case UINT1: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).uInt1()); + } else { + fieldReader.copyAsValue(mapWriter.uInt1(fieldName)); + } + break; + case UINT2: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).uInt2()); + } else { + fieldReader.copyAsValue(mapWriter.uInt2(fieldName)); + } + break; + case UINT4: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).uInt4()); + } else { + fieldReader.copyAsValue(mapWriter.uInt4(fieldName)); + } + break; + case UINT8: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).uInt8()); + } else { + fieldReader.copyAsValue(mapWriter.uInt8(fieldName)); + } + break; + case DECIMAL9: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).decimal9()); + } else { + fieldReader.copyAsValue(mapWriter.decimal9(fieldName)); + } + break; + case DECIMAL18: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).decimal18()); + } else { + fieldReader.copyAsValue(mapWriter.decimal18(fieldName)); + } + break; + case DECIMAL28SPARSE: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).decimal28Sparse()); + } else { + fieldReader.copyAsValue(mapWriter.decimal28Sparse(fieldName)); + } + break; + case DECIMAL38SPARSE: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).decimal38Sparse()); + } else { + fieldReader.copyAsValue(mapWriter.decimal38Sparse(fieldName)); + } + break; + case VARDECIMAL: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).varDecimal(valueMajorType.getScale(), valueMajorType.getPrecision())); + } else { + 
fieldReader.copyAsValue(mapWriter.varDecimal(fieldName, valueMajorType.getScale(), valueMajorType.getPrecision())); + } + break; + case DATE: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).date()); + } else { + fieldReader.copyAsValue(mapWriter.date(fieldName)); + } + break; + case TIME: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).time()); + } else { + fieldReader.copyAsValue(mapWriter.time(fieldName)); + } + break; + case TIMESTAMP: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).timeStamp()); + } else { + fieldReader.copyAsValue(mapWriter.timeStamp(fieldName)); + } + break; + case INTERVAL: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).interval()); + } else { + fieldReader.copyAsValue(mapWriter.interval(fieldName)); + } + break; + case INTERVALDAY: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).intervalDay()); + } else { + fieldReader.copyAsValue(mapWriter.intervalDay(fieldName)); + } + break; + case INTERVALYEAR: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).intervalYear()); + } else { + fieldReader.copyAsValue(mapWriter.intervalYear(fieldName)); + } + break; + case FLOAT4: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).float4()); + } else { + fieldReader.copyAsValue(mapWriter.float4(fieldName)); + } + break; + case FLOAT8: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).float8()); + } else { + fieldReader.copyAsValue(mapWriter.float8(fieldName)); + } + break; + case BIT: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).bit()); + } else { + fieldReader.copyAsValue(mapWriter.bit(fieldName)); + } + break; + case VARCHAR: + if (repeated) { + fieldReader.copyAsValue(mapWriter.list(fieldName).varChar()); + } else { + fieldReader.copyAsValue(mapWriter.varChar(fieldName)); + } + break; + case VARBINARY: + if (repeated) { + 
fieldReader.copyAsValue(mapWriter.list(fieldName).varBinary()); + } else { + fieldReader.copyAsValue(mapWriter.varBinary(fieldName)); + } + break; + case MAP: + if (valueMajorType.getMode() == TypeProtos.DataMode.REPEATED) { + fieldReader.copyAsValue(mapWriter.list(fieldName).map()); + } else { + fieldReader.copyAsValue(mapWriter.map(fieldName)); + } + break; + case LIST: + fieldReader.copyAsValue(mapWriter.list(fieldName).list()); + break; + default: + throw new DrillRuntimeException(String.format(caller + + " does not support input of type: %s", valueMinorType)); + } + } catch (ClassCastException e) { + final MaterializedField field = fieldReader.getField(); + throw new DrillRuntimeException(String.format(caller + TYPE_MISMATCH_ERROR, field.getName(), field.getType())); + } + } + + public static void writeToListFromReader(FieldReader fieldReader, BaseWriter.ListWriter listWriter, String caller) { + try { + MajorType valueMajorType = fieldReader.getType(); + MinorType valueMinorType = valueMajorType.getMinorType(); + boolean repeated = false; + + if (valueMajorType.getMode() == TypeProtos.DataMode.REPEATED) { + repeated = true; + } + + switch (valueMinorType) { + case TINYINT: + fieldReader.copyAsValue(listWriter.tinyInt()); + break; + case SMALLINT: + fieldReader.copyAsValue(listWriter.smallInt()); + break; + case BIGINT: + fieldReader.copyAsValue(listWriter.bigInt()); + break; + case INT: + fieldReader.copyAsValue(listWriter.integer()); + break; + case UINT1: + fieldReader.copyAsValue(listWriter.uInt1()); + break; + case UINT2: + fieldReader.copyAsValue(listWriter.uInt2()); + break; + case UINT4: + fieldReader.copyAsValue(listWriter.uInt4()); + break; + case UINT8: + fieldReader.copyAsValue(listWriter.uInt8()); + break; + case DECIMAL9: + fieldReader.copyAsValue(listWriter.decimal9()); + break; + case DECIMAL18: + fieldReader.copyAsValue(listWriter.decimal18()); + break; + case DECIMAL28SPARSE: + fieldReader.copyAsValue(listWriter.decimal28Sparse()); + break; + 
case DECIMAL38SPARSE: + fieldReader.copyAsValue(listWriter.decimal38Sparse()); + break; + case VARDECIMAL: + fieldReader.copyAsValue(listWriter.varDecimal(valueMajorType.getScale(), valueMajorType.getPrecision())); + break; + case DATE: + fieldReader.copyAsValue(listWriter.date()); + break; + case TIME: + fieldReader.copyAsValue(listWriter.time()); + break; + case TIMESTAMP: + fieldReader.copyAsValue(listWriter.timeStamp()); + break; + case INTERVAL: + fieldReader.copyAsValue(listWriter.interval()); + break; + case INTERVALDAY: + fieldReader.copyAsValue(listWriter.intervalDay()); + break; + case INTERVALYEAR: + fieldReader.copyAsValue(listWriter.intervalYear()); + break; + case FLOAT4: + fieldReader.copyAsValue(listWriter.float4()); + break; + case FLOAT8: + fieldReader.copyAsValue(listWriter.float8()); + break; + case BIT: + fieldReader.copyAsValue(listWriter.bit()); + break; + case VARCHAR: + fieldReader.copyAsValue(listWriter.varChar()); + break; + case VARBINARY: + fieldReader.copyAsValue(listWriter.varBinary()); + break; + case MAP: + fieldReader.copyAsValue(listWriter.map()); + break; + case LIST: + fieldReader.copyAsValue(listWriter.list()); + break; + default: + throw new DrillRuntimeException(String.format(caller + + " function does not support input of type: %s", valueMinorType)); + } + } catch (ClassCastException e) { + final MaterializedField field = fieldReader.getField(); + throw new DrillRuntimeException(String.format(caller + TYPE_MISMATCH_ERROR, field.getName(), field.getType())); } } } diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/agg/TestAggWithAnyValue.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/agg/TestAggWithAnyValue.java new file mode 100644 index 0000000000..37c0b52e90 --- /dev/null +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/agg/TestAggWithAnyValue.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.drill.exec.physical.impl.agg; + +import com.google.common.collect.Lists; +import org.apache.drill.exec.physical.config.StreamingAggregate; +import org.apache.drill.exec.physical.unit.PhysicalOpUnitTestBase; +import org.apache.drill.exec.util.JsonStringArrayList; +import org.apache.drill.test.BaseTestQuery; +import org.apache.drill.categories.OperatorTest; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import org.apache.drill.test.TestBuilder; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; + +import java.math.BigDecimal; +import java.util.List; + +@Category(OperatorTest.class) +@RunWith(Enclosed.class) +public class TestAggWithAnyValue { + + public static class TestAggWithAnyValueMultipleBatches extends PhysicalOpUnitTestBase { + + @Test + public void testStreamAggWithGroupBy() { + StreamingAggregate aggConf = new StreamingAggregate(null, parseExprs("age.`max`", "age"), parseExprs("any_value(a)", "any_a"), 2.0f); + List<String> inputJsonBatches = Lists.newArrayList( + "[{ \"age\": {\"min\":20, \"max\":60}, \"city\": \"San Bruno\", \"de\": \"987654321987654321987654321.10987654321\"," + + " \"a\": [{\"b\":50, \"c\":30},{\"b\":70, \"c\":40}], \"m\": 
[{\"n\": [10, 11, 12]}], \"f\": [{\"g\": {\"h\": [{\"k\": 70}, {\"k\": 80}]}}]," + + "\"p\": {\"q\": [21, 22, 23]}" + "}, " + + "{ \"age\": {\"min\":20, \"max\":60}, \"city\": \"Castro Valley\", \"de\": \"987654321987654321987654321.12987654321\"," + + " \"a\": [{\"b\":60, \"c\":40},{\"b\":80, \"c\":50}], \"m\": [{\"n\": [13, 14, 15]}], \"f\": [{\"g\": {\"h\": [{\"k\": 90}, {\"k\": 100}]}}]," + + "\"p\": {\"q\": [24, 25, 26]}" + "}]", + "[{ \"age\": {\"min\":43, \"max\":80}, \"city\": \"Palo Alto\", \"de\": \"987654321987654321987654321.00987654321\"," + + " \"a\": [{\"b\":10, \"c\":15}, {\"b\":20, \"c\":45}], \"m\": [{\"n\": [1, 2, 3]}], \"f\": [{\"g\": {\"h\": [{\"k\": 10}, {\"k\": 20}]}}]," + + "\"p\": {\"q\": [27, 28, 29]}" + "}, " + + "{ \"age\": {\"min\":43, \"max\":80}, \"city\": \"San Carlos\", \"de\": \"987654321987654321987654321.11987654321\"," + + " \"a\": [{\"b\":30, \"c\":25}, {\"b\":40, \"c\":55}], \"m\": [{\"n\": [4, 5, 6]}], \"f\": [{\"g\": {\"h\": [{\"k\": 30}, {\"k\": 40}]}}]," + + "\"p\": {\"q\": [30, 31, 32]}" + "}, " + + "{ \"age\": {\"min\":43, \"max\":80}, \"city\": \"Palo Alto\", \"de\": \"987654321987654321987654321.13987654321\"," + + " \"a\": [{\"b\":70, \"c\":85}, {\"b\":90, \"c\":145}], \"m\": [{\"n\": [7, 8, 9]}], \"f\": [{\"g\": {\"h\": [{\"k\": 50}, {\"k\": 60}]}}]," + + "\"p\": {\"q\": [33, 34, 35]}" + "}]"); + opTestBuilder() + .physicalOperator(aggConf) + .inputDataStreamJson(inputJsonBatches) + .baselineColumns("age", "any_a") + .baselineValues(60l, TestBuilder.listOf(TestBuilder.mapOf("b", 50l, "c", 30l), TestBuilder.mapOf("b", 70l, "c", 40l))) + .baselineValues(80l, TestBuilder.listOf(TestBuilder.mapOf("b", 10l, "c", 15l), TestBuilder.mapOf("b", 20l, "c", 45l))) + .go(); + } + } + + public static class TestAggWithAnyValueSingleBatch extends BaseTestQuery { + + @Test + public void testWithGroupBy() throws Exception { + String query = "select t1.age.`max` as age, count(*) as cnt, any_value(t1.a) as any_a, any_value(t1.city) as 
any_city, " + + "any_value(f) as any_f, any_value(m) as any_m, any_value(p) as any_p from cp.`store/json/test_anyvalue.json` t1 group by t1.age.`max`"; + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("age", "cnt", "any_a", "any_city", "any_f", "any_m", "any_p") + .baselineValues(60l, 2l, TestBuilder.listOf(TestBuilder.mapOf("b", 50l, "c", 30l), TestBuilder.mapOf("b", 70l, "c", 40l)), "San Bruno", + TestBuilder.listOf(TestBuilder.mapOf("g", TestBuilder.mapOf("h", TestBuilder.listOf(TestBuilder.mapOf("k", 70l), TestBuilder.mapOf("k", 80l))))), + TestBuilder.listOf(TestBuilder.mapOf("n", TestBuilder.listOf(10l, 11l, 12l))), + TestBuilder.mapOf("q", TestBuilder.listOf(21l, 22l, 23l))) + .baselineValues(80l, 3l, TestBuilder.listOf(TestBuilder.mapOf("b", 10l, "c", 15l), TestBuilder.mapOf("b", 20l, "c", 45l)), "Palo Alto", + TestBuilder.listOf(TestBuilder.mapOf("g", TestBuilder.mapOf("h", TestBuilder.listOf(TestBuilder.mapOf("k", 10l), TestBuilder.mapOf("k", 20l))))), + TestBuilder.listOf(TestBuilder.mapOf("n", TestBuilder.listOf(1l, 2l, 3l))), + TestBuilder.mapOf("q", TestBuilder.listOf(27l, 28l, 29l))) + .go(); + } + + @Test + public void testWithoutGroupBy() throws Exception { + String query = "select count(*) as cnt, any_value(t1.a) as any_a, any_value(t1.city) as any_city, " + + "any_value(f) as any_f, any_value(m) as any_m, any_value(p) as any_p from cp.`store/json/test_anyvalue.json` t1"; + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("cnt", "any_a", "any_city", "any_f", "any_m", "any_p") + .baselineValues(5l, TestBuilder.listOf(TestBuilder.mapOf("b", 10l, "c", 15l), TestBuilder.mapOf("b", 20l, "c", 45l)), "Palo Alto", + TestBuilder.listOf(TestBuilder.mapOf("g", TestBuilder.mapOf("h", TestBuilder.listOf(TestBuilder.mapOf("k", 10l), TestBuilder.mapOf("k", 20l))))), + TestBuilder.listOf(TestBuilder.mapOf("n", TestBuilder.listOf(1l, 2l, 3l))), + TestBuilder.mapOf("q", TestBuilder.listOf(27l, 28l, 29l))) + .go(); + } + + @Test 
+ public void testDecimalWithGroupBy() throws Exception { + String query = "select t1.age.`max` as age, any_value(cast(t1.de as decimal(38, 11))) as any_decimal " + + "from cp.`store/json/test_anyvalue.json` t1 group by t1.age.`max`"; + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("age", "any_decimal") + .baselineValues(60l, new BigDecimal("987654321987654321987654321.10987654321")) + .baselineValues(80l, new BigDecimal("987654321987654321987654321.00987654321")) + .go(); + } + + @Test + public void testRepeatedDecimalWithGroupBy() throws Exception { + JsonStringArrayList<BigDecimal> ints = new JsonStringArrayList<>(); + ints.add(new BigDecimal("999999.999")); + ints.add(new BigDecimal("-999999.999")); + ints.add(new BigDecimal("0.000")); + + JsonStringArrayList<BigDecimal> longs = new JsonStringArrayList<>(); + longs.add(new BigDecimal("999999999.999999999")); + longs.add(new BigDecimal("-999999999.999999999")); + longs.add(new BigDecimal("0.000000000")); + + JsonStringArrayList<BigDecimal> fixedLen = new JsonStringArrayList<>(); + fixedLen.add(new BigDecimal("999999999999.999999")); + fixedLen.add(new BigDecimal("-999999999999.999999")); + fixedLen.add(new BigDecimal("0.000000")); + + String query = "select any_value(decimal_int32) as any_dec_32, any_value(decimal_int64) as any_dec_64," + + " any_value(decimal_fixedLen) as any_dec_fixed, any_value(decimal_binary) as any_dec_bin" + + " from cp.`parquet/repeatedIntLondFixedLenBinaryDecimal.parquet`"; + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("any_dec_32", "any_dec_64", "any_dec_fixed", "any_dec_bin") + .baselineValues(ints, longs, fixedLen, fixedLen) + .go(); + } + } +} \ No newline at end of file diff --git a/exec/java-exec/src/test/resources/store/json/test_anyvalue.json b/exec/java-exec/src/test/resources/store/json/test_anyvalue.json new file mode 100644 index 0000000000..8e7bef983e --- /dev/null +++ 
b/exec/java-exec/src/test/resources/store/json/test_anyvalue.json @@ -0,0 +1,50 @@ +{ + "age": {"min":43, "max":80}, + "city": "Palo Alto", + "de": "987654321987654321987654321.00987654321", + "delist": ["987654321987654321987654321.19987654321", "987654321987654321987654321.20987654321"], + "a": [{"b":10, "c":15}, {"b":20, "c":45}], + "m": [{"n": [1, 2, 3]}], + "f": [{"g": {"h": [{"k": 10}, {"k": 20}]}}], + "p": {"q" : [27, 28, 29]} +} +{ + "age": {"min":20, "max":60}, + "city": "San Bruno", + "de": "987654321987654321987654321.10987654321", + "delist": ["987654321987654321987654321.17987654321", "987654321987654321987654321.18987654321"], + "a": [{"b":50, "c":30},{"b":70, "c":40}], + "m": [{"n": [10, 11, 12]}], + "f": [{"g": {"h": [{"k": 70}, {"k": 80}]}}], + "p": {"q" : [21, 22, 23]} +} +{ + "age": {"min":43, "max":80}, + "city": "San Carlos", + "de": "987654321987654321987654321.11987654321", + "delist": ["987654321987654321987654321.11987654321", "987654321987654321987654321.12987654321"], + "a": [{"b":30, "c":25}, {"b":40, "c":55}], + "m": [{"n": [4, 5, 6]}], + "f": [{"g": {"h": [{"k": 30}, {"k": 40}]}}], + "p": {"q" : [30, 31, 32]} +} +{ + "age": {"min":20, "max":60}, + "city": "Castro Valley", + "de": "987654321987654321987654321.12987654321", + "delist": ["987654321987654321987654321.13987654321", "987654321987654321987654321.14987654321"], + "a": [{"b":60, "c":40},{"b":80, "c":50}], + "m": [{"n": [13, 14, 15]}], + "f": [{"g": {"h": [{"k": 90}, {"k": 100}]}}], + "p": {"q" : [24, 25, 26]} +} +{ + "age": {"min":43, "max":80}, + "city": "Palo Alto", + "de": "987654321987654321987654321.13987654321", + "delist": ["987654321987654321987654321.15987654321", "987654321987654321987654321.16987654321"], + "a": [{"b":70, "c":85}, {"b":90, "c":145}], + "m": [{"n": [7, 8, 9]}], + "f": [{"g": {"h": [{"k": 50}, {"k": 60}]}}], + "p": {"q" : [33, 34, 35]} +} \ No newline at end of file diff --git a/exec/vector/src/main/codegen/templates/RepeatedValueVectors.java 
b/exec/vector/src/main/codegen/templates/RepeatedValueVectors.java index 4e6edb580d..037332f073 100644 --- a/exec/vector/src/main/codegen/templates/RepeatedValueVectors.java +++ b/exec/vector/src/main/codegen/templates/RepeatedValueVectors.java @@ -307,6 +307,7 @@ public void get(int index, Repeated${minor.class}Holder holder) { holder.start = offsets.getAccessor().get(index); holder.end = offsets.getAccessor().get(index+1); holder.vector = values; + holder.reader = reader; } public void get(int index, int positionIndex, ${minor.class}Holder holder) { diff --git a/exec/vector/src/main/codegen/templates/ValueHolders.java b/exec/vector/src/main/codegen/templates/ValueHolders.java index 7635895be3..8562d1b6bc 100644 --- a/exec/vector/src/main/codegen/templates/ValueHolders.java +++ b/exec/vector/src/main/codegen/templates/ValueHolders.java @@ -49,16 +49,17 @@ /** The first index (inclusive) into the Vector. **/ public int start; - + /** The last index (exclusive) into the Vector. **/ public int end; - + /** The Vector holding the actual values. 
**/ public ${minor.class}Vector vector; - - <#else> + + public FieldReader reader; +<#else> public static final int WIDTH = ${type.width}; - + <#if mode.name == "Optional">public int isSet;</#if> <#assign fields = minor.fields!type.fields /> <#list fields as field> diff --git a/exec/vector/src/main/java/org/apache/drill/exec/expr/holders/RepeatedListHolder.java b/exec/vector/src/main/java/org/apache/drill/exec/expr/holders/RepeatedListHolder.java index 52f590ac2a..ce7e34dfe3 100644 --- a/exec/vector/src/main/java/org/apache/drill/exec/expr/holders/RepeatedListHolder.java +++ b/exec/vector/src/main/java/org/apache/drill/exec/expr/holders/RepeatedListHolder.java @@ -20,6 +20,7 @@ import org.apache.drill.common.types.TypeProtos; import org.apache.drill.common.types.Types; import org.apache.drill.exec.vector.complex.ListVector; +import org.apache.drill.exec.vector.complex.reader.FieldReader; public final class RepeatedListHolder implements ValueHolder{ @@ -36,4 +37,5 @@ /** The Vector holding the actual values. **/ public ListVector vector; + public FieldReader reader; } diff --git a/exec/vector/src/main/java/org/apache/drill/exec/expr/holders/RepeatedMapHolder.java b/exec/vector/src/main/java/org/apache/drill/exec/expr/holders/RepeatedMapHolder.java index f8acaebee2..516d1358bf 100644 --- a/exec/vector/src/main/java/org/apache/drill/exec/expr/holders/RepeatedMapHolder.java +++ b/exec/vector/src/main/java/org/apache/drill/exec/expr/holders/RepeatedMapHolder.java @@ -20,6 +20,7 @@ import org.apache.drill.common.types.TypeProtos; import org.apache.drill.common.types.Types; import org.apache.drill.exec.vector.complex.MapVector; +import org.apache.drill.exec.vector.complex.reader.FieldReader; public final class RepeatedMapHolder implements ValueHolder{ @@ -38,4 +39,6 @@ /** The Vector holding the actual values. 
**/ public MapVector vector; + public FieldReader reader; + } diff --git a/logical/src/main/antlr3/org/apache/drill/common/expression/parser/ExprLexer.g b/logical/src/main/antlr3/org/apache/drill/common/expression/parser/ExprLexer.g index 2b497a1959..93dba9478f 100644 --- a/logical/src/main/antlr3/org/apache/drill/common/expression/parser/ExprLexer.g +++ b/logical/src/main/antlr3/org/apache/drill/common/expression/parser/ExprLexer.g @@ -37,6 +37,7 @@ When : 'when'; Cast: 'cast'; Convert : 'convert_' ('from' | 'to'); +AnyValue : 'any_value' | 'ANY_VALUE'; Nullable: 'nullable'; Repeat: 'repeat'; As: 'as'; diff --git a/logical/src/main/antlr3/org/apache/drill/common/expression/parser/ExprParser.g b/logical/src/main/antlr3/org/apache/drill/common/expression/parser/ExprParser.g index e73bdea00e..78a7cc3297 100644 --- a/logical/src/main/antlr3/org/apache/drill/common/expression/parser/ExprParser.g +++ b/logical/src/main/antlr3/org/apache/drill/common/expression/parser/ExprParser.g @@ -81,6 +81,10 @@ convertCall returns [LogicalExpression e] { $e = FunctionCallFactory.createConvert($Convert.text, $String.text, $expression.e, pos($Convert));} ; +anyValueCall returns [LogicalExpression e] + : AnyValue OParen exprList? 
CParen {$e = FunctionCallFactory.createExpression($AnyValue.text, pos($AnyValue), $exprList.listE); } + ; + castCall returns [LogicalExpression e] @init{ List<LogicalExpression> exprs = new ArrayList<LogicalExpression>(); @@ -313,6 +317,7 @@ arraySegment returns [PathSegment seg] lookup returns [LogicalExpression e] : functionCall {$e = $functionCall.e ;} | convertCall {$e = $convertCall.e; } + | anyValueCall {$e = $anyValueCall.e; } | castCall {$e = $castCall.e; } | pathSegment {$e = new SchemaPath($pathSegment.seg, pos($pathSegment.start) ); } | String {$e = new ValueExpressions.QuotedString($String.text, $String.text.length(), pos($String) ); } diff --git a/logical/src/main/java/org/apache/drill/common/expression/AnyValueExpression.java b/logical/src/main/java/org/apache/drill/common/expression/AnyValueExpression.java new file mode 100644 index 0000000000..4dff14752f --- /dev/null +++ b/logical/src/main/java/org/apache/drill/common/expression/AnyValueExpression.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.drill.common.expression; + +import org.apache.drill.common.expression.visitors.ExprVisitor; +import org.apache.drill.common.types.TypeProtos.MajorType; +import org.apache.drill.common.types.Types; + +import java.util.Collections; +import java.util.Iterator; + +public class AnyValueExpression extends LogicalExpressionBase implements Iterable<LogicalExpression>{ + + static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(AnyValueExpression.class); + + public static final String ANY_VALUE = "any_value"; + private final LogicalExpression input; + private final MajorType type; + + /** + * @param input + * @param pos + */ + public AnyValueExpression(LogicalExpression input, ExpressionPosition pos) { + super(pos); + this.input = input; + this.type = input.getMajorType(); + } + + @Override + public <T, V, E extends Exception> T accept(ExprVisitor<T, V, E> visitor, V value) throws E { + return visitor.visitAnyValueExpression(this, value); + } + + @Override + public Iterator<LogicalExpression> iterator() { + return Collections.singleton(input).iterator(); + } + + public LogicalExpression getInput() { + return input; + } + + @Override + public MajorType getMajorType() { + return type; + } + + @Override + public String toString() { + return "AnyValueExpression [input=" + input + ", type=" + Types.toString(type) + "]"; + } +} + diff --git a/logical/src/main/java/org/apache/drill/common/expression/ExpressionStringBuilder.java b/logical/src/main/java/org/apache/drill/common/expression/ExpressionStringBuilder.java index f09f887412..fb5323b2dd 100644 --- a/logical/src/main/java/org/apache/drill/common/expression/ExpressionStringBuilder.java +++ b/logical/src/main/java/org/apache/drill/common/expression/ExpressionStringBuilder.java @@ -251,6 +251,14 @@ public Void visitConvertExpression(ConvertExpression e, StringBuilder sb) throws return null; } + @Override + public Void visitAnyValueExpression(AnyValueExpression e, StringBuilder sb) throws 
RuntimeException { + sb.append("any("); + e.getInput().accept(this, sb); + sb.append(")"); + return null; + } + @Override public Void visitCastExpression(CastExpression e, StringBuilder sb) throws RuntimeException { MajorType mt = e.getMajorType(); diff --git a/logical/src/main/java/org/apache/drill/common/expression/FunctionCallFactory.java b/logical/src/main/java/org/apache/drill/common/expression/FunctionCallFactory.java index 7d9f9a1077..513da4cc4a 100644 --- a/logical/src/main/java/org/apache/drill/common/expression/FunctionCallFactory.java +++ b/logical/src/main/java/org/apache/drill/common/expression/FunctionCallFactory.java @@ -85,6 +85,14 @@ public static LogicalExpression createConvert(String function, String conversion return new ConvertExpression(function, conversionType, expr, ep); } + public static LogicalExpression createAnyValue(ExpressionPosition ep, LogicalExpression expr) { + return new AnyValueExpression(expr, ep); + } + + public static LogicalExpression createAnyValue(String functionName, List<LogicalExpression> args) { + return createExpression(functionName, args); + } + public static LogicalExpression createExpression(String functionName, List<LogicalExpression> args){ return createExpression(functionName, ExpressionPosition.UNKNOWN, args); } diff --git a/logical/src/main/java/org/apache/drill/common/expression/visitors/AbstractExprVisitor.java b/logical/src/main/java/org/apache/drill/common/expression/visitors/AbstractExprVisitor.java index 8458968493..18483ce791 100644 --- a/logical/src/main/java/org/apache/drill/common/expression/visitors/AbstractExprVisitor.java +++ b/logical/src/main/java/org/apache/drill/common/expression/visitors/AbstractExprVisitor.java @@ -17,6 +17,7 @@ */ package org.apache.drill.common.expression.visitors; +import org.apache.drill.common.expression.AnyValueExpression; import org.apache.drill.common.expression.BooleanOperator; import org.apache.drill.common.expression.CastExpression; import 
org.apache.drill.common.expression.ConvertExpression; @@ -164,6 +165,11 @@ public T visitConvertExpression(ConvertExpression e, VAL value) throws EXCEP { return visitUnknown(e, value); } + @Override + public T visitAnyValueExpression(AnyValueExpression e, VAL value) throws EXCEP { + return visitUnknown(e, value); + } + @Override public T visitNullConstant(TypedNullConstant e, VAL value) throws EXCEP { return visitUnknown(e, value); diff --git a/logical/src/main/java/org/apache/drill/common/expression/visitors/AggregateChecker.java b/logical/src/main/java/org/apache/drill/common/expression/visitors/AggregateChecker.java index 2e6b60b5e3..ac46e42fc4 100644 --- a/logical/src/main/java/org/apache/drill/common/expression/visitors/AggregateChecker.java +++ b/logical/src/main/java/org/apache/drill/common/expression/visitors/AggregateChecker.java @@ -17,6 +17,7 @@ */ package org.apache.drill.common.expression.visitors; +import org.apache.drill.common.expression.AnyValueExpression; import org.apache.drill.common.expression.BooleanOperator; import org.apache.drill.common.expression.CastExpression; import org.apache.drill.common.expression.ConvertExpression; @@ -176,6 +177,11 @@ public Boolean visitConvertExpression(ConvertExpression e, ErrorCollector errors return e.getInput().accept(this, errors); } + @Override + public Boolean visitAnyValueExpression(AnyValueExpression e, ErrorCollector errors) throws RuntimeException { + return e.getInput().accept(this, errors); + } + @Override public Boolean visitDateConstant(DateExpression intExpr, ErrorCollector errors) { return false; diff --git a/logical/src/main/java/org/apache/drill/common/expression/visitors/ConditionalExprOptimizer.java b/logical/src/main/java/org/apache/drill/common/expression/visitors/ConditionalExprOptimizer.java index 05e3a7364e..1b6eab79aa 100644 --- a/logical/src/main/java/org/apache/drill/common/expression/visitors/ConditionalExprOptimizer.java +++ 
b/logical/src/main/java/org/apache/drill/common/expression/visitors/ConditionalExprOptimizer.java @@ -21,6 +21,7 @@ import java.util.Comparator; import java.util.List; +import org.apache.drill.common.expression.AnyValueExpression; import org.apache.drill.common.expression.BooleanOperator; import org.apache.drill.common.expression.CastExpression; import org.apache.drill.common.expression.ConvertExpression; @@ -100,6 +101,12 @@ public LogicalExpression visitConvertExpression(ConvertExpression cast, Void val + "It should have been converted to FunctionHolderExpression in materialization"); } + @Override + public LogicalExpression visitAnyValueExpression(AnyValueExpression cast, Void value) throws RuntimeException { + throw new UnsupportedOperationException("AnyValueExpression is not expected here. " + + "It should have been converted to FunctionHolderExpression in materialization"); + } + private static Comparator<LogicalExpression> costComparator = new Comparator<LogicalExpression> () { public int compare(LogicalExpression e1, LogicalExpression e2) { return e1.getCumulativeCost() <= e2.getCumulativeCost() ? 
-1 : 1; diff --git a/logical/src/main/java/org/apache/drill/common/expression/visitors/ConstantChecker.java b/logical/src/main/java/org/apache/drill/common/expression/visitors/ConstantChecker.java index fbe7d721b7..a7648ac177 100644 --- a/logical/src/main/java/org/apache/drill/common/expression/visitors/ConstantChecker.java +++ b/logical/src/main/java/org/apache/drill/common/expression/visitors/ConstantChecker.java @@ -17,6 +17,7 @@ */ package org.apache.drill.common.expression.visitors; +import org.apache.drill.common.expression.AnyValueExpression; import org.apache.drill.common.expression.BooleanOperator; import org.apache.drill.common.expression.CastExpression; import org.apache.drill.common.expression.ConvertExpression; @@ -201,6 +202,11 @@ public Boolean visitConvertExpression(ConvertExpression e, ErrorCollector value) return e.getInput().accept(this, value); } + @Override + public Boolean visitAnyValueExpression(AnyValueExpression e, ErrorCollector value) throws RuntimeException { + return e.getInput().accept(this, value); + } + @Override public Boolean visitNullConstant(TypedNullConstant e, ErrorCollector value) throws RuntimeException { return true; diff --git a/logical/src/main/java/org/apache/drill/common/expression/visitors/ExprVisitor.java b/logical/src/main/java/org/apache/drill/common/expression/visitors/ExprVisitor.java index c065bc8835..cea83d8718 100644 --- a/logical/src/main/java/org/apache/drill/common/expression/visitors/ExprVisitor.java +++ b/logical/src/main/java/org/apache/drill/common/expression/visitors/ExprVisitor.java @@ -17,6 +17,7 @@ */ package org.apache.drill.common.expression.visitors; +import org.apache.drill.common.expression.AnyValueExpression; import org.apache.drill.common.expression.BooleanOperator; import org.apache.drill.common.expression.CastExpression; import org.apache.drill.common.expression.ConvertExpression; @@ -75,4 +76,5 @@ T visitConvertExpression(ConvertExpression e, VAL value) throws EXCEP; T 
visitParameter(ParameterExpression e, VAL value) throws EXCEP; T visitTypedFieldExpr(TypedFieldExpr e, VAL value) throws EXCEP; + T visitAnyValueExpression(AnyValueExpression e, VAL value) throws EXCEP; } diff --git a/logical/src/main/java/org/apache/drill/common/expression/visitors/ExpressionValidator.java b/logical/src/main/java/org/apache/drill/common/expression/visitors/ExpressionValidator.java index b3074fcb1d..df72a34151 100644 --- a/logical/src/main/java/org/apache/drill/common/expression/visitors/ExpressionValidator.java +++ b/logical/src/main/java/org/apache/drill/common/expression/visitors/ExpressionValidator.java @@ -17,6 +17,7 @@ */ package org.apache.drill.common.expression.visitors; +import org.apache.drill.common.expression.AnyValueExpression; import org.apache.drill.common.expression.BooleanOperator; import org.apache.drill.common.expression.CastExpression; import org.apache.drill.common.expression.ConvertExpression; @@ -238,6 +239,12 @@ public Void visitConvertExpression(ConvertExpression e, ErrorCollector value) return e.getInput().accept(this, value); } + @Override + public Void visitAnyValueExpression(AnyValueExpression e, ErrorCollector value) + throws RuntimeException { + return e.getInput().accept(this, value); + } + @Override public Void visitParameter(ValueExpressions.ParameterExpression e, ErrorCollector value) throws RuntimeException { return null; ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org > ANY_VALUE aggregate function > ---------------------------- > > Key: DRILL-6375 > URL: https://issues.apache.org/jira/browse/DRILL-6375 > Project: Apache Drill > Issue Type: New Feature > Components: Functions - Drill > Affects Versions: 1.13.0 > Reporter: Gautam Kumar Parai > Assignee: Gautam Kumar Parai > Priority: Major > Labels: ready-to-commit > Fix For: 1.14.0 > > > We had discussions on the Apache Calcite [1] and Apache Drill [2] mailing > lists regarding an equivalent for DISTINCT ON. The community seems to prefer > ANY_VALUE. This Jira is a placeholder for implementing the ANY_VALUE > aggregate function in Apache Drill. We should also eventually contribute it > to Apache Calcite. > [1] https://lists.apache.org/thread.html/f2007a489d3a5741875bcc8a1edd8d5c3715e5114ac45058c3b3a42d@%3Cdev.calcite.apache.org%3E > [2] https://lists.apache.org/thread.html/2517eef7410aed4e88b9515f7e4256335215c1ad39a2676a08d21cb9@%3Cdev.drill.apache.org%3E -- This message was sent by Atlassian JIRA (v7.6.3#76005)