HIVE-10785 : Support aggregate push down through joins (Ashutosh Chauhan via Jesus Camacho Rodriguez)
Project: http://git-wip-us.apache.org/repos/asf/hive/repo Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/68d6cfda Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/68d6cfda Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/68d6cfda Branch: refs/heads/beeline-cli Commit: 68d6cfda78b3ec6b42cf0d42df62aa1f2716d414 Parents: 1528135 Author: Ashutosh Chauhan <[email protected]> Authored: Thu Sep 17 21:49:00 2015 -0800 Committer: Ashutosh Chauhan <[email protected]> Committed: Thu Sep 24 13:58:50 2015 -0700 ---------------------------------------------------------------------- .../org/apache/hadoop/hive/conf/HiveConf.java | 2 +- .../hadoop/hive/ql/exec/FunctionRegistry.java | 3 +- .../functions/HiveSqlCountAggFunction.java | 72 + .../functions/HiveSqlMinMaxAggFunction.java | 49 + .../functions/HiveSqlSumAggFunction.java | 125 ++ .../rules/HiveAggregateJoinTransposeRule.java | 372 +++++ .../translator/SqlFunctionConverter.java | 40 +- .../hadoop/hive/ql/parse/CalcitePlanner.java | 5 + .../hive/ql/udf/generic/GenericUDAFSum.java | 2 +- .../udf/generic/GenericUDAFSumEmptyIsZero.java | 63 + .../clientpositive/groupby_join_pushdown.q | 55 + .../clientpositive/groupby_join_pushdown.q.out | 1522 ++++++++++++++++++ .../results/clientpositive/show_functions.q.out | 1 + 13 files changed, 2297 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java ---------------------------------------------------------------------- diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index f3e2168..dffdb5c 100644 --- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -814,7 +814,7 @@ public class HiveConf extends Configuration { + " expressed as multiple of Local FS write cost"), HIVE_CBO_COST_MODEL_HDFS_READ("hive.cbo.costmodel.hdfs.read", "1.5", "Default cost of reading a byte from HDFS;" + " expressed as multiple of Local FS read cost"), - + AGGR_JOIN_TRANSPOSE("hive.transpose.aggr.join", false, "push aggregates through join"), // hive.mapjoin.bucket.cache.size has been replaced by hive.smbjoin.cache.row, // need to remove by hive .13. Also, do not change default (see SMB operator) http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index f1fe30d..218b2df 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -370,6 +370,7 @@ public final class FunctionRegistry { system.registerGenericUDAF("min", new GenericUDAFMin()); system.registerGenericUDAF("sum", new GenericUDAFSum()); + system.registerGenericUDAF("$SUM0", new GenericUDAFSumEmptyIsZero()); system.registerGenericUDAF("count", new GenericUDAFCount()); system.registerGenericUDAF("avg", new GenericUDAFAverage()); system.registerGenericUDAF("std", new GenericUDAFStd()); @@ -960,7 +961,7 @@ public final class FunctionRegistry { GenericUDAFParameterInfo paramInfo = new SimpleGenericUDAFParameterInfo( args, isDistinct, isAllColumns); - + GenericUDAFEvaluator udafEvaluator; if (udafResolver instanceof GenericUDAFResolver2) { udafEvaluator = http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java new file mode 100644 index 0000000..7937040 --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.functions; + +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlSplittableAggFunction; +import org.apache.calcite.sql.SqlSplittableAggFunction.CountSplitter; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlOperandTypeInference; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableIntList; + +public class HiveSqlCountAggFunction extends SqlAggFunction { + + final SqlReturnTypeInference returnTypeInference; + final SqlOperandTypeInference operandTypeInference; + final SqlOperandTypeChecker operandTypeChecker; + + public HiveSqlCountAggFunction(SqlReturnTypeInference returnTypeInference, + SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) { + super( + "count", + SqlKind.OTHER_FUNCTION, + returnTypeInference, + operandTypeInference, + operandTypeChecker, + SqlFunctionCategory.NUMERIC); + this.returnTypeInference = returnTypeInference; + this.operandTypeChecker = operandTypeChecker; + this.operandTypeInference = operandTypeInference; + } + + @Override + public <T> T unwrap(Class<T> clazz) { + if (clazz == SqlSplittableAggFunction.class) { + return clazz.cast(new HiveCountSplitter()); + } + return super.unwrap(clazz); + } + + class HiveCountSplitter extends CountSplitter { + + @Override + public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) { + + return AggregateCall.create( + new HiveSqlCountAggFunction(returnTypeInference, operandTypeInference, operandTypeChecker), + false, ImmutableIntList.of(), -1, + typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true), "count"); + } + } +} http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlMinMaxAggFunction.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlMinMaxAggFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlMinMaxAggFunction.java new file mode 100644 index 0000000..77dca1f --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlMinMaxAggFunction.java @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.functions; + +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlSplittableAggFunction; +import org.apache.calcite.sql.SqlSplittableAggFunction.SelfSplitter; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlOperandTypeInference; +import org.apache.calcite.sql.type.SqlReturnTypeInference; + +public class HiveSqlMinMaxAggFunction extends SqlAggFunction { + + public HiveSqlMinMaxAggFunction(SqlReturnTypeInference returnTypeInference, + SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker, boolean isMin) { + super( + isMin ? "min" : "max", + SqlKind.OTHER_FUNCTION, + returnTypeInference, + operandTypeInference, + operandTypeChecker, + SqlFunctionCategory.NUMERIC); + } + + @Override + public <T> T unwrap(Class<T> clazz) { + if (clazz == SqlSplittableAggFunction.class) { + return clazz.cast(SelfSplitter.INSTANCE); + } + return super.unwrap(clazz); + } +} http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java new file mode 100644 index 0000000..8f62970 --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java @@ -0,0 +1,125 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.functions; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlSplittableAggFunction; +import org.apache.calcite.sql.SqlSplittableAggFunction.SumSplitter; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlOperandTypeInference; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableIntList; + +import com.google.common.collect.ImmutableList; + +/** + * <code>Sum</code> is an aggregator which returns the sum of the values which + * go into it. It has precisely one argument of numeric type (<code>int</code>, + * <code>long</code>, <code>float</code>, <code>double</code>), and the result + * is the same type. + */ +public class HiveSqlSumAggFunction extends SqlAggFunction { + + final SqlReturnTypeInference returnTypeInference; + final SqlOperandTypeInference operandTypeInference; + final SqlOperandTypeChecker operandTypeChecker; + + //~ Constructors ----------------------------------------------------------- + + public HiveSqlSumAggFunction(SqlReturnTypeInference returnTypeInference, + SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) { + super( + "sum", + SqlKind.OTHER_FUNCTION, + returnTypeInference, + operandTypeInference, + operandTypeChecker, + SqlFunctionCategory.NUMERIC); + this.returnTypeInference = returnTypeInference; + this.operandTypeChecker = operandTypeChecker; + this.operandTypeInference = operandTypeInference; + } + + //~ Methods ---------------------------------------------------------------- + + + @Override + public <T> T unwrap(Class<T> clazz) { + if (clazz == SqlSplittableAggFunction.class) { + return clazz.cast(new HiveSumSplitter()); + } + return super.unwrap(clazz); + } + + class HiveSumSplitter extends SumSplitter { + + @Override + public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) { + RelDataType countRetType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true); + return AggregateCall.create( + new HiveSqlCountAggFunction(ReturnTypes.explicit(countRetType), operandTypeInference, operandTypeChecker), + false, ImmutableIntList.of(), -1, countRetType, "count"); + } + + @Override + public AggregateCall topSplit(RexBuilder rexBuilder, + Registry<RexNode> extra, int offset, RelDataType inputRowType, + AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal) { + final List<RexNode> merges = new ArrayList<>(); + final List<RelDataTypeField> fieldList = inputRowType.getFieldList(); + if (leftSubTotal >= 0) { + final RelDataType type = fieldList.get(leftSubTotal).getType(); + merges.add(rexBuilder.makeInputRef(type, leftSubTotal)); + } + if (rightSubTotal >= 0) { + final RelDataType type = fieldList.get(rightSubTotal).getType(); + merges.add(rexBuilder.makeInputRef(type, rightSubTotal)); + } + RexNode node; + switch (merges.size()) { + case 1: + node = merges.get(0); + break; + case 2: + node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges); + node = rexBuilder.makeAbstractCast(aggregateCall.type, node); + break; + default: + throw new AssertionError("unexpected count " + merges); + } + int ordinal = extra.register(node); + return AggregateCall.create(new HiveSqlSumAggFunction(returnTypeInference, operandTypeInference, operandTypeChecker), + false, ImmutableList.of(ordinal), -1, aggregateCall.type, aggregateCall.name); + } + } +} + +// End SqlSumAggFunction.java http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateJoinTransposeRule.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateJoinTransposeRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateJoinTransposeRule.java new file mode 100644 index 0000000..211b6fa --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateJoinTransposeRule.java @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.rules.AggregateJoinTransposeRule; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlSplittableAggFunction; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.mapping.Mapping; +import org.apache.calcite.util.mapping.Mappings; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * Planner rule that pushes an + * {@link org.apache.calcite.rel.core.Aggregate} + * past a {@link org.apache.calcite.rel.core.Join}. + */ +public class HiveAggregateJoinTransposeRule extends AggregateJoinTransposeRule { + + /** Extended instance of the rule that can push down aggregate functions. */ + public static final HiveAggregateJoinTransposeRule INSTANCE = + new HiveAggregateJoinTransposeRule(HiveAggregate.class, HiveAggregate.HIVE_AGGR_REL_FACTORY, + HiveJoin.class, HiveJoin.HIVE_JOIN_FACTORY, HiveProject.DEFAULT_PROJECT_FACTORY, true); + + private final RelFactories.AggregateFactory aggregateFactory; + + private final RelFactories.JoinFactory joinFactory; + + private final RelFactories.ProjectFactory projectFactory; + + private final boolean allowFunctions; + + /** Creates an AggregateJoinTransposeRule that may push down functions. */ + private HiveAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, + RelFactories.AggregateFactory aggregateFactory, + Class<? extends Join> joinClass, + RelFactories.JoinFactory joinFactory, + RelFactories.ProjectFactory projectFactory, + boolean allowFunctions) { + super(aggregateClass, aggregateFactory, joinClass, joinFactory, projectFactory, true); + this.aggregateFactory = aggregateFactory; + this.joinFactory = joinFactory; + this.projectFactory = projectFactory; + this.allowFunctions = allowFunctions; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + final Join join = call.rel(1); + final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); + + // If any aggregate functions do not support splitting, bail out + // If any aggregate call has a filter, bail out + for (AggregateCall aggregateCall : aggregate.getAggCallList()) { + if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) + == null) { + return; + } + if (aggregateCall.filterArg >= 0) { + return; + } + } + + // If it is not an inner join, we do not push the + // aggregate operator + if (join.getJoinType() != JoinRelType.INNER) { + return; + } + + if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) { + return; + } + + // Do the columns used by the join appear in the output of the aggregate? + final ImmutableBitSet aggregateColumns = aggregate.getGroupSet(); + final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, + RelMetadataQuery.getPulledUpPredicates(join).pulledUpPredicates); + final ImmutableBitSet joinColumns = + RelOptUtil.InputFinder.bits(join.getCondition()); + final boolean allColumnsInAggregate = + keyColumns.contains(joinColumns); + final ImmutableBitSet belowAggregateColumns = + aggregateColumns.union(joinColumns); + + // Split join condition + final List<Integer> leftKeys = Lists.newArrayList(); + final List<Integer> rightKeys = Lists.newArrayList(); + RexNode nonEquiConj = + RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), + join.getCondition(), leftKeys, rightKeys); + // If it contains non-equi join conditions, we bail out + if (!nonEquiConj.isAlwaysTrue()) { + return; + } + + // Push each aggregate function down to each side that contains all of its + // arguments. Note that COUNT(*), because it has no arguments, can go to + // both sides. + final Map<Integer, Integer> map = new HashMap<>(); + final List<Side> sides = new ArrayList<>(); + int uniqueCount = 0; + int offset = 0; + int belowOffset = 0; + for (int s = 0; s < 2; s++) { + final Side side = new Side(); + final RelNode joinInput = join.getInput(s); + int fieldCount = joinInput.getRowType().getFieldCount(); + final ImmutableBitSet fieldSet = + ImmutableBitSet.range(offset, offset + fieldCount); + final ImmutableBitSet belowAggregateKeyNotShifted = + belowAggregateColumns.intersect(fieldSet); + for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) { + map.put(c.e, belowOffset + c.i); + } + final ImmutableBitSet belowAggregateKey = + belowAggregateKeyNotShifted.shift(-offset); + final boolean unique; + if (!allowFunctions) { + assert aggregate.getAggCallList().isEmpty(); + // If there are no functions, it doesn't matter as much whether we + // aggregate the inputs before the join, because there will not be + // any functions experiencing a cartesian product effect. + // + // But finding out whether the input is already unique requires a call + // to areColumnsUnique that currently (until [CALCITE-794] "Detect + // cycles when computing statistics" is fixed) places a heavy load on + // the metadata system. + // + // So we choose to imagine the the input is already unique, which is + // untrue but harmless. + // + unique = true; + } else { + final Boolean unique0 = + RelMetadataQuery.areColumnsUnique(joinInput, belowAggregateKey); + unique = unique0 != null && unique0; + } + if (unique) { + ++uniqueCount; + side.newInput = joinInput; + } else { + List<AggregateCall> belowAggCalls = new ArrayList<>(); + final SqlSplittableAggFunction.Registry<AggregateCall> + belowAggCallRegistry = registry(belowAggCalls); + final Mappings.TargetMapping mapping = + s == 0 + ? Mappings.createIdentity(fieldCount) + : Mappings.createShiftMapping(fieldCount + offset, 0, offset, + fieldCount); + for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) { + final SqlAggFunction aggregation = aggCall.e.getAggregation(); + final SqlSplittableAggFunction splitter = + Preconditions.checkNotNull( + aggregation.unwrap(SqlSplittableAggFunction.class)); + final AggregateCall call1; + if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) { + call1 = splitter.split(aggCall.e, mapping); + } else { + call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e); + } + if (call1 != null) { + side.split.put(aggCall.i, + belowAggregateKey.cardinality() + + belowAggCallRegistry.register(call1)); + } + } + side.newInput = aggregateFactory.createAggregate(joinInput, false, + belowAggregateKey, null, belowAggCalls); + } + offset += fieldCount; + belowOffset += side.newInput.getRowType().getFieldCount(); + sides.add(side); + } + + if (uniqueCount == 2) { + // Both inputs to the join are unique. There is nothing to be gained by + // this rule. In fact, this aggregate+join may be the result of a previous + // invocation of this rule; if we continue we might loop forever. + return; + } + + // Update condition + final Mapping mapping = (Mapping) Mappings.target( + new Function<Integer, Integer>() { + @Override + public Integer apply(Integer a0) { + return map.get(a0); + } + }, + join.getRowType().getFieldCount(), + belowOffset); + final RexNode newCondition = + RexUtil.apply(mapping, join.getCondition()); + + // Create new join + RelNode newJoin = joinFactory.createJoin(sides.get(0).newInput, + sides.get(1).newInput, newCondition, join.getJoinType(), + join.getVariablesStopped(), join.isSemiJoinDone()); + + // Aggregate above to sum up the sub-totals + final List<AggregateCall> newAggCalls = new ArrayList<>(); + final int groupIndicatorCount = + aggregate.getGroupCount() + aggregate.getIndicatorCount(); + final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount(); + final List<RexNode> projects = + new ArrayList<>(rexBuilder.identityProjects(newJoin.getRowType())); + for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) { + final SqlAggFunction aggregation = aggCall.e.getAggregation(); + final SqlSplittableAggFunction splitter = + Preconditions.checkNotNull( + aggregation.unwrap(SqlSplittableAggFunction.class)); + final Integer leftSubTotal = sides.get(0).split.get(aggCall.i); + final Integer rightSubTotal = sides.get(1).split.get(aggCall.i); + newAggCalls.add( + splitter.topSplit(rexBuilder, registry(projects), + groupIndicatorCount, newJoin.getRowType(), aggCall.e, + leftSubTotal == null ? -1 : leftSubTotal, + rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth)); + } + RelNode r = newJoin; + b: + if (allColumnsInAggregate && newAggCalls.isEmpty() && + RelOptUtil.areRowTypesEqual(r.getRowType(), aggregate.getRowType(), false)) { + // no need to aggregate + } else { + r = RelOptUtil.createProject(r, projects, null, true, projectFactory); + if (allColumnsInAggregate) { + // let's see if we can convert + List<RexNode> projects2 = new ArrayList<>(); + for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) { + projects2.add(rexBuilder.makeInputRef(r, key)); + } + for (AggregateCall newAggCall : newAggCalls) { + final SqlSplittableAggFunction splitter = + newAggCall.getAggregation() + .unwrap(SqlSplittableAggFunction.class); + if (splitter != null) { + projects2.add( + splitter.singleton(rexBuilder, r.getRowType(), newAggCall)); + } + } + if (projects2.size() + == aggregate.getGroupSet().cardinality() + newAggCalls.size()) { + // We successfully converted agg calls into projects. + r = RelOptUtil.createProject(r, projects2, null, true, projectFactory); + break b; + } + } + r = aggregateFactory.createAggregate(r, aggregate.indicator, + Mappings.apply(mapping, aggregate.getGroupSet()), + Mappings.apply2(mapping, aggregate.getGroupSets()), newAggCalls); + } + call.transformTo(r); + } + + /** Computes the closure of a set of columns according to a given list of + * constraints. Each 'x = y' constraint causes bit y to be set if bit x is + * set, and vice versa. */ + private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, + ImmutableList<RexNode> predicates) { + SortedMap<Integer, BitSet> equivalence = new TreeMap<>(); + for (RexNode pred : predicates) { + populateEquivalences(equivalence, pred); + } + ImmutableBitSet keyColumns = aggregateColumns; + for (Integer aggregateColumn : aggregateColumns) { + final BitSet bitSet = equivalence.get(aggregateColumn); + if (bitSet != null) { + keyColumns = keyColumns.union(bitSet); + } + } + return keyColumns; + } + + private static void populateEquivalences(Map<Integer, BitSet> equivalence, + RexNode predicate) { + switch (predicate.getKind()) { + case EQUALS: + RexCall call = (RexCall) predicate; + final List<RexNode> operands = call.getOperands(); + if (operands.get(0) instanceof RexInputRef) { + final RexInputRef ref0 = (RexInputRef) operands.get(0); + if (operands.get(1) instanceof RexInputRef) { + final RexInputRef ref1 = (RexInputRef) operands.get(1); + populateEquivalence(equivalence, ref0.getIndex(), ref1.getIndex()); + populateEquivalence(equivalence, ref1.getIndex(), ref0.getIndex()); + } + } + } + } + + private static void populateEquivalence(Map<Integer, BitSet> equivalence, + int i0, int i1) { + BitSet bitSet = equivalence.get(i0); + if (bitSet == null) { + bitSet = new BitSet(); + equivalence.put(i0, bitSet); + } + bitSet.set(i1); + } + + /** Creates a {@link org.apache.calcite.sql.SqlSplittableAggFunction.Registry} + * that is a view of a list. */ + private static <E> SqlSplittableAggFunction.Registry<E> + registry(final List<E> list) { + return new SqlSplittableAggFunction.Registry<E>() { + @Override + public int register(E e) { + int i = list.indexOf(e); + if (i < 0) { + i = list.size(); + list.add(e); + } + return i; + } + }; + } + + /** Work space for an input to a join. */ + private static class Side { + final Map<Integer, Integer> split = new HashMap<>(); + RelNode newInput; + } +} + +// End AggregateJoinTransposeRule.java + http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java index fd78824..d59c6bb 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java @@ -45,6 +45,9 @@ import org.apache.hadoop.hive.ql.exec.FunctionRegistry; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException; import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException.UnsupportedFeature; +import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction; +import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlMinMaxAggFunction; +import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveBetween; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveIn; import org.apache.hadoop.hive.ql.parse.ASTNode; @@ -310,6 +313,7 @@ public class SqlFunctionConverter { registerFunction("in", HiveIn.INSTANCE, hToken(HiveParser.Identifier, "in")); registerFunction("between", HiveBetween.INSTANCE, hToken(HiveParser.Identifier, "between")); registerFunction("struct", SqlStdOperatorTable.ROW, hToken(HiveParser.Identifier, "struct")); + } private void registerFunction(String name, SqlOperator calciteFn, HiveToken hiveToken) { @@ -339,8 +343,7 @@ public class SqlFunctionConverter { // UDAF is assumed to be deterministic public static class CalciteUDAF extends SqlAggFunction { public CalciteUDAF(String opName, SqlReturnTypeInference returnTypeInference, - SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker, - ImmutableList<RelDataType> argTypes, RelDataType retType) { + SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) { super(opName, SqlKind.OTHER_FUNCTION, returnTypeInference, operandTypeInference, operandTypeChecker, SqlFunctionCategory.USER_DEFINED_FUNCTION); } @@ -367,8 +370,6 @@ public class SqlFunctionConverter { private SqlReturnTypeInference returnTypeInference; private SqlOperandTypeInference operandTypeInference; private SqlOperandTypeChecker operandTypeChecker; - private ImmutableList<RelDataType> argTypes; - private RelDataType retType; } private static CalciteUDFInfo getUDFInfo(String hiveUdfName, @@ -382,10 +383,6 @@ public class SqlFunctionConverter { typeFamilyBuilder.add(Util.first(at.getSqlTypeName().getFamily(), SqlTypeFamily.ANY)); } udfInfo.operandTypeChecker = OperandTypes.family(typeFamilyBuilder.build()); - - udfInfo.argTypes = ImmutableList.<RelDataType> copyOf(calciteArgTypes); - udfInfo.retType = calciteRetType; - return udfInfo; } @@ -413,13 +410,34 @@ public class SqlFunctionConverter { public static SqlAggFunction getCalciteAggFn(String hiveUdfName, ImmutableList<RelDataType> calciteArgTypes, RelDataType calciteRetType) { SqlAggFunction calciteAggFn = (SqlAggFunction) hiveToCalcite.get(hiveUdfName); + if (calciteAggFn == null) { CalciteUDFInfo uInf = getUDFInfo(hiveUdfName, calciteArgTypes, calciteRetType); - calciteAggFn = new CalciteUDAF(uInf.udfName, uInf.returnTypeInference, - uInf.operandTypeInference, uInf.operandTypeChecker, uInf.argTypes, uInf.retType); - } + switch (hiveUdfName.toLowerCase()) { + case "sum": + calciteAggFn = new HiveSqlSumAggFunction(uInf.returnTypeInference, + uInf.operandTypeInference, uInf.operandTypeChecker); + break; + case "count": + calciteAggFn = new HiveSqlCountAggFunction(uInf.returnTypeInference, + uInf.operandTypeInference, uInf.operandTypeChecker); + break; + case "min": + calciteAggFn = new HiveSqlMinMaxAggFunction(uInf.returnTypeInference, + uInf.operandTypeInference, uInf.operandTypeChecker, true); + break; + case "max": + calciteAggFn = new HiveSqlMinMaxAggFunction(uInf.returnTypeInference, + uInf.operandTypeInference, uInf.operandTypeChecker, false); + break; + default: + calciteAggFn = new CalciteUDAF(uInf.udfName, uInf.returnTypeInference, + uInf.operandTypeInference, uInf.operandTypeChecker); + break; + } + } return calciteAggFn; } http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java index 0a7ce3a..9c731b8 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java @@ -63,6 +63,7 @@ import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.metadata.CachingRelMetadataProvider; import org.apache.calcite.rel.metadata.ChainedRelMetadataProvider; import org.apache.calcite.rel.metadata.RelMetadataProvider; +import org.apache.calcite.rel.rules.AggregateJoinTransposeRule; import org.apache.calcite.rel.rules.FilterAggregateTransposeRule; import org.apache.calcite.rel.rules.FilterProjectTransposeRule; import org.apache.calcite.rel.rules.JoinToMultiJoinRule; @@ -134,6 +135,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSemiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortLimit; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion; +import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregateJoinTransposeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregateProjectMergeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveExpandDistinctAggregatesRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveFilterJoinRule; @@ -885,6 +887,9 @@ public class CalcitePlanner extends SemanticAnalyzer { hepPgmBldr.addRuleInstance(UnionMergeRule.INSTANCE); hepPgmBldr.addRuleInstance(new ProjectMergeRule(false, HiveProject.DEFAULT_PROJECT_FACTORY)); hepPgmBldr.addRuleInstance(HiveAggregateProjectMergeRule.INSTANCE); + if (conf.getBoolVar(ConfVars.AGGR_JOIN_TRANSPOSE)) { + hepPgmBldr.addRuleInstance(HiveAggregateJoinTransposeRule.INSTANCE); + } hepPgm = hepPgmBldr.build(); HepPlanner hepPlanner = new HepPlanner(hepPgm); http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java index 5a5846e..c6ffbec 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java @@ -356,7 +356,7 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { */ public static class GenericUDAFSumLong extends GenericUDAFEvaluator { private PrimitiveObjectInspector inputOI; - private LongWritable result; + protected LongWritable result; @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSumEmptyIsZero.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSumEmptyIsZero.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSumEmptyIsZero.java new file mode 100644 index 0000000..ab7ab04 --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSumEmptyIsZero.java @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.udf.generic; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +@Description(name = "$SUM0", value = "_FUNC_(x) - Returns the sum of a set of numbers, zero if empty") +public class GenericUDAFSumEmptyIsZero extends GenericUDAFSum { + + @Override + public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) + throws SemanticException { + if (parameters.length != 1) { + throw new UDFArgumentTypeException(parameters.length - 1, + "Exactly one argument is expected."); + } + + if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, + "Only primitive type arguments are accepted but " + + parameters[0].getTypeName() + " is passed."); + } + switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) { + case LONG: + return new SumZeroIfEmpty(); + default: + throw new UDFArgumentTypeException(0, + "Only bigint type arguments are accepted but " + + parameters[0].getTypeName() + " is passed."); + } + } + + public static class SumZeroIfEmpty extends GenericUDAFSumLong { + + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + SumLongAgg myagg = (SumLongAgg) agg; + result.set(myagg.sum); + return result; + } + } +} + http://git-wip-us.apache.org/repos/asf/hive/blob/68d6cfda/ql/src/test/queries/clientpositive/groupby_join_pushdown.q ---------------------------------------------------------------------- diff --git a/ql/src/test/queries/clientpositive/groupby_join_pushdown.q b/ql/src/test/queries/clientpositive/groupby_join_pushdown.q new file mode 100644 index 0000000..bf1ae4b --- /dev/null +++ b/ql/src/test/queries/clientpositive/groupby_join_pushdown.q @@ -0,0 +1,55 @@ +set hive.transpose.aggr.join=true; +EXPLAIN +SELECT f.key, g.key, count(g.key) +FROM src f JOIN src g ON(f.key = g.key) +GROUP BY f.key, g.key; + +EXPLAIN +SELECT f.key, g.key +FROM src f JOIN src g ON(f.key = g.key) +GROUP BY f.key, g.key; + +EXPLAIN +SELECT DISTINCT f.value, g.value +FROM src f JOIN src g ON(f.value = g.value); + +EXPLAIN +SELECT f.key, g.key, COUNT(*) +FROM src f JOIN src g ON(f.key = g.key) +GROUP BY f.key, g.key; + +EXPLAIN +SELECT f.ctinyint, g.ctinyint, SUM(f.cbigint) +FROM alltypesorc f JOIN alltypesorc g ON(f.cint = g.cint) +GROUP BY f.ctinyint, g.ctinyint ; + +EXPLAIN +SELECT f.cbigint, g.cbigint, MAX(f.cint) +FROM alltypesorc f JOIN alltypesorc g ON(f.cbigint = g.cbigint) +GROUP BY f.cbigint, g.cbigint ; + +explain +SELECT f.ctinyint, g.ctinyint, MIN(f.ctinyint) +FROM alltypesorc f JOIN alltypesorc g ON(f.ctinyint = g.ctinyint) +GROUP BY f.ctinyint, g.ctinyint; + +explain +SELECT MIN(f.cint) +FROM alltypesorc f JOIN alltypesorc g ON(f.ctinyint = g.ctinyint) +GROUP BY f.ctinyint, g.ctinyint; + +explain +SELECT count(f.ctinyint) +FROM alltypesorc f JOIN alltypesorc g ON(f.ctinyint = g.ctinyint) +GROUP BY f.ctinyint, g.ctinyint; + +explain +SELECT count(f.cint), f.ctinyint +FROM alltypesorc f JOIN alltypesorc g ON(f.ctinyint = g.ctinyint) +GROUP BY f.ctinyint, g.ctinyint; + +explain +SELECT sum(f.cint), f.ctinyint +FROM alltypesorc f JOIN alltypesorc g ON(f.ctinyint = g.ctinyint) +GROUP BY f.ctinyint, g.ctinyint; +
