HIVE-10533 : CBO (Calcite Return Path): Join to MultiJoin support for outer joins (Jesus Camacho Rodriguez via Ashutosh Chauhan)
Signed-off-by: Ashutosh Chauhan <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/hive/repo Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/6d19df3a Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/6d19df3a Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/6d19df3a Branch: refs/heads/beeline-cli Commit: 6d19df3aceb26e182d9e0340ab6df44665c042f6 Parents: 11020ae Author: Jesus Camacho Rodriguez <[email protected]> Authored: Wed Jun 24 10:46:00 2015 -0700 Committer: Ashutosh Chauhan <[email protected]> Committed: Wed Jun 24 18:20:05 2015 -0700 ---------------------------------------------------------------------- .../ql/optimizer/calcite/HiveCalciteUtil.java | 129 +- .../ql/optimizer/calcite/HiveRelOptUtil.java | 351 + .../calcite/reloperators/HiveMultiJoin.java | 198 + .../rules/HiveInsertExchange4JoinRule.java | 27 +- .../rules/HiveJoinProjectTransposeRule.java | 60 + .../calcite/rules/HiveJoinToMultiJoinRule.java | 309 +- .../calcite/rules/HiveProjectMergeRule.java | 1 - .../calcite/rules/HiveRelFieldTrimmer.java | 106 + .../calcite/translator/HiveOpConverter.java | 185 +- .../hadoop/hive/ql/parse/CalcitePlanner.java | 22 +- .../apache/hadoop/hive/ql/plan/JoinDesc.java | 7 - .../test/queries/clientpositive/cbo_rp_join0.q | 26 + .../test/queries/clientpositive/cbo_rp_join1.q | 22 + .../results/clientpositive/cbo_rp_join0.q.out | 6867 ++++++++++++++++++ .../results/clientpositive/cbo_rp_join1.q.out | 426 ++ 15 files changed, 8429 insertions(+), 307 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java index 199a358..024097e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java @@ -26,12 +26,12 @@ import java.util.Map.Entry; import java.util.Set; import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelOptUtil.InputFinder; import org.apache.calcite.plan.RelOptUtil.InputReferencedVisitor; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.RelFactories.ProjectFactory; import org.apache.calcite.rel.core.Sort; -import org.apache.calcite.rel.rules.MultiJoin; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; @@ -53,8 +53,11 @@ import org.apache.calcite.sql.validate.SqlValidatorUtil; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.Pair; import org.apache.calcite.util.Util; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.ql.metadata.VirtualColumn; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.ExprNodeConverter; import org.apache.hadoop.hive.ql.parse.ASTNode; @@ -76,6 +79,9 @@ import com.google.common.collect.Sets; public class HiveCalciteUtil { + private static final Log LOG = LogFactory.getLog(HiveCalciteUtil.class); + + /** * Get list of virtual columns from the given list of projections. * <p> @@ -336,16 +342,16 @@ public class HiveCalciteUtil { public static JoinPredicateInfo constructJoinPredicateInfo(Join j) { return constructJoinPredicateInfo(j, j.getCondition()); } - - public static JoinPredicateInfo constructJoinPredicateInfo(MultiJoin mj) { - return constructJoinPredicateInfo(mj, mj.getJoinFilter()); + + public static JoinPredicateInfo constructJoinPredicateInfo(HiveMultiJoin mj) { + return constructJoinPredicateInfo(mj, mj.getCondition()); } public static JoinPredicateInfo constructJoinPredicateInfo(Join j, RexNode predicate) { return constructJoinPredicateInfo(j.getInputs(), j.getSystemFieldList(), predicate); } - public static JoinPredicateInfo constructJoinPredicateInfo(MultiJoin mj, RexNode predicate) { + public static JoinPredicateInfo constructJoinPredicateInfo(HiveMultiJoin mj, RexNode predicate) { final List<RelDataTypeField> systemFieldList = ImmutableList.of(); return constructJoinPredicateInfo(mj.getInputs(), systemFieldList, predicate); } @@ -383,24 +389,24 @@ public class HiveCalciteUtil { // 2.2 Classify leaf predicate as Equi vs Non Equi if (jlpi.comparisonType.equals(SqlKind.EQUALS)) { equiLPIList.add(jlpi); - } else { - nonEquiLPIList.add(jlpi); - } - - // 2.3 Maintain join keys (in child & Join Schema) - // 2.4 Update Join Key to JoinLeafPredicateInfo map with keys - for (int i=0; i<inputs.size(); i++) { - projsJoinKeys.get(i).addAll(jlpi.getProjsJoinKeysInChildSchema(i)); - projsJoinKeysInJoinSchema.get(i).addAll(jlpi.getProjsJoinKeysInJoinSchema(i)); - for (Integer projIndx : jlpi.getProjsJoinKeysInJoinSchema(i)) { - tmpJLPILst = tmpMapOfProjIndxInJoinSchemaToLeafPInfo.get(projIndx); - if (tmpJLPILst == null) { - tmpJLPILst = new ArrayList<JoinLeafPredicateInfo>(); + // 2.2.1 Maintain join keys (in child & Join Schema) + // 2.2.2 Update Join Key to JoinLeafPredicateInfo map with keys + for (int i=0; i<inputs.size(); i++) { + projsJoinKeys.get(i).addAll(jlpi.getProjsJoinKeysInChildSchema(i)); + projsJoinKeysInJoinSchema.get(i).addAll(jlpi.getProjsJoinKeysInJoinSchema(i)); + + for (Integer projIndx : jlpi.getProjsJoinKeysInJoinSchema(i)) { + tmpJLPILst = tmpMapOfProjIndxInJoinSchemaToLeafPInfo.get(projIndx); + if (tmpJLPILst == null) { + tmpJLPILst = new ArrayList<JoinLeafPredicateInfo>(); + } + tmpJLPILst.add(jlpi); + tmpMapOfProjIndxInJoinSchemaToLeafPInfo.put(projIndx, tmpJLPILst); } - tmpJLPILst.add(jlpi); - tmpMapOfProjIndxInJoinSchemaToLeafPInfo.put(projIndx, tmpJLPILst); } + } else { + nonEquiLPIList.add(jlpi); } } @@ -448,21 +454,21 @@ public class HiveCalciteUtil { this.joinKeyExprs = joinKeyExprsBuilder.build(); ImmutableList.Builder<ImmutableSet<Integer>> projsJoinKeysInChildSchemaBuilder = ImmutableList.builder(); - for (int i=0; i<joinKeyExprs.size(); i++) { + for (int i=0; i<projsJoinKeysInChildSchema.size(); i++) { projsJoinKeysInChildSchemaBuilder.add( ImmutableSet.copyOf(projsJoinKeysInChildSchema.get(i))); } this.projsJoinKeysInChildSchema = projsJoinKeysInChildSchemaBuilder.build(); ImmutableList.Builder<ImmutableSet<Integer>> projsJoinKeysInJoinSchemaBuilder = ImmutableList.builder(); - for (int i=0; i<joinKeyExprs.size(); i++) { + for (int i=0; i<projsJoinKeysInJoinSchema.size(); i++) { projsJoinKeysInJoinSchemaBuilder.add( ImmutableSet.copyOf(projsJoinKeysInJoinSchema.get(i))); } this.projsJoinKeysInJoinSchema = projsJoinKeysInJoinSchemaBuilder.build(); } - public List<RexNode> getJoinKeyExprs(int input) { + public List<RexNode> getJoinExprs(int input) { return this.joinKeyExprs.get(input); } @@ -494,48 +500,67 @@ public class HiveCalciteUtil { return this.projsJoinKeysInJoinSchema.get(input); } + // We create the join predicate info object. The object contains the join condition, + // split accordingly. If the join condition is not part of the equi-join predicate, + // the returned object will be typed as SQLKind.OTHER. private static JoinLeafPredicateInfo constructJoinLeafPredicateInfo(List<RelNode> inputs, List<RelDataTypeField> systemFieldList, RexNode pe) { JoinLeafPredicateInfo jlpi = null; List<Integer> filterNulls = new ArrayList<Integer>(); - List<List<RexNode>> joinKeyExprs = new ArrayList<List<RexNode>>(); + List<List<RexNode>> joinExprs = new ArrayList<List<RexNode>>(); for (int i=0; i<inputs.size(); i++) { - joinKeyExprs.add(new ArrayList<RexNode>()); + joinExprs.add(new ArrayList<RexNode>()); } // 1. Split leaf join predicate to expressions from left, right - RelOptUtil.splitJoinCondition(systemFieldList, inputs, pe, - joinKeyExprs, filterNulls, null); + RexNode otherConditions = HiveRelOptUtil.splitJoinCondition(systemFieldList, inputs, pe, + joinExprs, filterNulls, null); - // 2. Collect child projection indexes used - List<Set<Integer>> projsJoinKeysInChildSchema = - new ArrayList<Set<Integer>>(); - for (int i=0; i<inputs.size(); i++) { - ImmutableSet.Builder<Integer> projsFromInputJoinKeysInChildSchema = ImmutableSet.builder(); - InputReferencedVisitor irvLeft = new InputReferencedVisitor(); - irvLeft.apply(joinKeyExprs.get(i)); - projsFromInputJoinKeysInChildSchema.addAll(irvLeft.inputPosReferenced); - projsJoinKeysInChildSchema.add(projsFromInputJoinKeysInChildSchema.build()); - } + if (otherConditions.isAlwaysTrue()) { + // 2. Collect child projection indexes used + List<Set<Integer>> projsJoinKeysInChildSchema = + new ArrayList<Set<Integer>>(); + for (int i=0; i<inputs.size(); i++) { + ImmutableSet.Builder<Integer> projsFromInputJoinKeysInChildSchema = ImmutableSet.builder(); + InputReferencedVisitor irvLeft = new InputReferencedVisitor(); + irvLeft.apply(joinExprs.get(i)); + projsFromInputJoinKeysInChildSchema.addAll(irvLeft.inputPosReferenced); + projsJoinKeysInChildSchema.add(projsFromInputJoinKeysInChildSchema.build()); + } + + // 3. Translate projection indexes to join schema, by adding offset. + List<Set<Integer>> projsJoinKeysInJoinSchema = + new ArrayList<Set<Integer>>(); + // The offset of the first input does not need to change. + projsJoinKeysInJoinSchema.add(projsJoinKeysInChildSchema.get(0)); + for (int i=1; i<inputs.size(); i++) { + int offSet = inputs.get(i-1).getRowType().getFieldCount(); + ImmutableSet.Builder<Integer> projsFromInputJoinKeysInJoinSchema = ImmutableSet.builder(); + for (Integer indx : projsJoinKeysInChildSchema.get(i)) { + projsFromInputJoinKeysInJoinSchema.add(indx + offSet); + } + projsJoinKeysInJoinSchema.add(projsFromInputJoinKeysInJoinSchema.build()); + } - // 3. Translate projection indexes to join schema, by adding offset. - List<Set<Integer>> projsJoinKeysInJoinSchema = - new ArrayList<Set<Integer>>(); - // The offset of the first input does not need to change. - projsJoinKeysInJoinSchema.add(projsJoinKeysInChildSchema.get(0)); - for (int i=1; i<inputs.size(); i++) { - int offSet = inputs.get(i-1).getRowType().getFieldCount(); - ImmutableSet.Builder<Integer> projsFromInputJoinKeysInJoinSchema = ImmutableSet.builder(); - for (Integer indx : projsJoinKeysInChildSchema.get(i)) { - projsFromInputJoinKeysInJoinSchema.add(indx + offSet); + // 4. Construct JoinLeafPredicateInfo + jlpi = new JoinLeafPredicateInfo(pe.getKind(), joinExprs, + projsJoinKeysInChildSchema, projsJoinKeysInJoinSchema); + } else { + // 2. Construct JoinLeafPredicateInfo + ImmutableBitSet refCols = InputFinder.bits(pe); + int count = 0; + for (int i=0; i<inputs.size(); i++) { + final int length = inputs.get(i).getRowType().getFieldCount(); + ImmutableBitSet inputRange = ImmutableBitSet.range(count, count + length); + if (inputRange.contains(refCols)) { + joinExprs.get(i).add(pe); + } + count += length; } - projsJoinKeysInJoinSchema.add(projsFromInputJoinKeysInJoinSchema.build()); + jlpi = new JoinLeafPredicateInfo(SqlKind.OTHER, joinExprs, + new ArrayList<Set<Integer>>(), new ArrayList<Set<Integer>>()); } - // 4. Construct JoinLeafPredicateInfo - jlpi = new JoinLeafPredicateInfo(pe.getKind(), joinKeyExprs, - projsJoinKeysInChildSchema, projsJoinKeysInJoinSchema); - return jlpi; } } http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java new file mode 100644 index 0000000..9ebb24f --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java @@ -0,0 +1,351 @@ +package org.apache.hadoop.hive.ql.optimizer.calcite; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Util; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import com.google.common.collect.ImmutableList; + +public class HiveRelOptUtil extends RelOptUtil { + + private static final Log LOG = LogFactory.getLog(HiveRelOptUtil.class); + + + /** + * Splits out the equi-join (and optionally, a single non-equi) components + * of a join condition, and returns what's left. Projection might be + * required by the caller to provide join keys that are not direct field + * references. + * + * @param sysFieldList list of system fields + * @param inputs join inputs + * @param condition join condition + * @param joinKeys The join keys from the inputs which are equi-join + * keys + * @param filterNulls The join key positions for which null values will not + * match. null values only match for the "is not distinct + * from" condition. + * @param rangeOp if null, only locate equi-joins; otherwise, locate a + * single non-equi join predicate and return its operator + * in this list; join keys associated with the non-equi + * join predicate are at the end of the key lists + * returned + * @return What's left, never null + */ + public static RexNode splitJoinCondition( + List<RelDataTypeField> sysFieldList, + List<RelNode> inputs, + RexNode condition, + List<List<RexNode>> joinKeys, + List<Integer> filterNulls, + List<SqlOperator> rangeOp) { + final List<RexNode> nonEquiList = new ArrayList<>(); + + splitJoinCondition( + sysFieldList, + inputs, + condition, + joinKeys, + filterNulls, + rangeOp, + nonEquiList); + + // Convert the remainders into a list that are AND'ed together. + return RexUtil.composeConjunction( + inputs.get(0).getCluster().getRexBuilder(), nonEquiList, false); + } + + private static void splitJoinCondition( + List<RelDataTypeField> sysFieldList, + List<RelNode> inputs, + RexNode condition, + List<List<RexNode>> joinKeys, + List<Integer> filterNulls, + List<SqlOperator> rangeOp, + List<RexNode> nonEquiList) { + final int sysFieldCount = sysFieldList.size(); + final RelOptCluster cluster = inputs.get(0).getCluster(); + final RexBuilder rexBuilder = cluster.getRexBuilder(); + final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); + + final ImmutableBitSet[] inputsRange = new ImmutableBitSet[inputs.size()]; + int totalFieldCount = 0; + for (int i = 0; i < inputs.size(); i++) { + final int firstField = totalFieldCount + sysFieldCount; + totalFieldCount = firstField + inputs.get(i).getRowType().getFieldCount(); + inputsRange[i] = ImmutableBitSet.range(firstField, totalFieldCount); + } + + // adjustment array + int[] adjustments = new int[totalFieldCount]; + for (int i = 0; i < inputs.size(); i++) { + final int adjustment = inputsRange[i].nextSetBit(0); + for (int j = adjustment; j < inputsRange[i].length(); j++) { + adjustments[j] = -adjustment; + } + } + + if (condition instanceof RexCall) { + RexCall call = (RexCall) condition; + if (call.getOperator() == SqlStdOperatorTable.AND) { + for (RexNode operand : call.getOperands()) { + splitJoinCondition( + sysFieldList, + inputs, + operand, + joinKeys, + filterNulls, + rangeOp, + nonEquiList); + } + return; + } + + RexNode leftKey = null; + RexNode rightKey = null; + int leftInput = 0; + int rightInput = 0; + List<RelDataTypeField> leftFields = null; + List<RelDataTypeField> rightFields = null; + boolean reverse = false; + + SqlKind kind = call.getKind(); + + // Only consider range operators if we haven't already seen one + if ((kind == SqlKind.EQUALS) + || (filterNulls != null + && kind == SqlKind.IS_NOT_DISTINCT_FROM) + || (rangeOp != null + && rangeOp.isEmpty() + && (kind == SqlKind.GREATER_THAN + || kind == SqlKind.GREATER_THAN_OR_EQUAL + || kind == SqlKind.LESS_THAN + || kind == SqlKind.LESS_THAN_OR_EQUAL))) { + final List<RexNode> operands = call.getOperands(); + RexNode op0 = operands.get(0); + RexNode op1 = operands.get(1); + + final ImmutableBitSet projRefs0 = InputFinder.bits(op0); + final ImmutableBitSet projRefs1 = InputFinder.bits(op1); + + boolean foundBothInputs = false; + for (int i = 0; i < inputs.size() && !foundBothInputs; i++) { + if (projRefs0.intersects(inputsRange[i]) + && projRefs0.union(inputsRange[i]).equals(inputsRange[i])) { + if (leftKey == null) { + leftKey = op0; + leftInput = i; + leftFields = inputs.get(leftInput).getRowType().getFieldList(); + } else { + rightKey = op0; + rightInput = i; + rightFields = inputs.get(rightInput).getRowType().getFieldList(); + reverse = true; + foundBothInputs = true; + } + } else if (projRefs1.intersects(inputsRange[i]) + && projRefs1.union(inputsRange[i]).equals(inputsRange[i])) { + if (leftKey == null) { + leftKey = op1; + leftInput = i; + leftFields = inputs.get(leftInput).getRowType().getFieldList(); + } else { + rightKey = op1; + rightInput = i; + rightFields = inputs.get(rightInput).getRowType().getFieldList(); + foundBothInputs = true; + } + } + } + + if ((leftKey != null) && (rightKey != null)) { + // replace right Key input ref + rightKey = + rightKey.accept( + new RelOptUtil.RexInputConverter( + rexBuilder, + rightFields, + rightFields, + adjustments)); + + // left key only needs to be adjusted if there are system + // fields, but do it for uniformity + leftKey = + leftKey.accept( + new RelOptUtil.RexInputConverter( + rexBuilder, + leftFields, + leftFields, + adjustments)); + + RelDataType leftKeyType = leftKey.getType(); + RelDataType rightKeyType = rightKey.getType(); + + if (leftKeyType != rightKeyType) { + // perform casting + RelDataType targetKeyType = + typeFactory.leastRestrictive( + ImmutableList.of(leftKeyType, rightKeyType)); + + if (targetKeyType == null) { + throw Util.newInternal( + "Cannot find common type for join keys " + + leftKey + " (type " + leftKeyType + ") and " + + rightKey + " (type " + rightKeyType + ")"); + } + + if (leftKeyType != targetKeyType) { + leftKey = + rexBuilder.makeCast(targetKeyType, leftKey); + } + + if (rightKeyType != targetKeyType) { + rightKey = + rexBuilder.makeCast(targetKeyType, rightKey); + } + } + } + } + +// if ((rangeOp == null) +// && ((leftKey == null) || (rightKey == null))) { +// // no equality join keys found yet: +// // try transforming the condition to +// // equality "join" conditions, e.g. +// // f(LHS) > 0 ===> ( f(LHS) > 0 ) = TRUE, +// // and make the RHS produce TRUE, but only if we're strictly +// // looking for equi-joins +// final ImmutableBitSet projRefs = InputFinder.bits(condition); +// leftKey = null; +// rightKey = null; +// +// boolean foundInput = false; +// for (int i = 0; i < inputs.size() && !foundInput; i++) { +// final int lowerLimit = inputsRange[i].nextSetBit(0); +// final int upperLimit = inputsRange[i].length(); +// if (projRefs.nextSetBit(lowerLimit) < upperLimit) { +// leftInput = i; +// leftFields = inputs.get(leftInput).getRowType().getFieldList(); +// +// leftKey = condition.accept( +// new RelOptUtil.RexInputConverter( +// rexBuilder, +// leftFields, +// leftFields, +// adjustments)); +// +// rightKey = rexBuilder.makeLiteral(true); +// +// // effectively performing an equality comparison +// kind = SqlKind.EQUALS; +// +// foundInput = true; +// } +// } +// } + + if ((leftKey != null) && (rightKey != null)) { + // found suitable join keys + // add them to key list, ensuring that if there is a + // non-equi join predicate, it appears at the end of the + // key list; also mark the null filtering property + addJoinKey( + joinKeys.get(leftInput), + leftKey, + (rangeOp != null) && !rangeOp.isEmpty()); + addJoinKey( + joinKeys.get(rightInput), + rightKey, + (rangeOp != null) && !rangeOp.isEmpty()); + if (filterNulls != null + && kind == SqlKind.EQUALS) { + // nulls are considered not matching for equality comparison + // add the position of the most recently inserted key + filterNulls.add(joinKeys.get(leftInput).size() - 1); + } + if (rangeOp != null + && kind != SqlKind.EQUALS + && kind != SqlKind.IS_DISTINCT_FROM) { + if (reverse) { + kind = reverse(kind); + } + rangeOp.add(op(kind, call.getOperator())); + } + return; + } // else fall through and add this condition as nonEqui condition + } + + // The operator is not of RexCall type + // So we fail. Fall through. + // Add this condition to the list of non-equi-join conditions. + nonEquiList.add(condition); + } + + private static SqlKind reverse(SqlKind kind) { + switch (kind) { + case GREATER_THAN: + return SqlKind.LESS_THAN; + case GREATER_THAN_OR_EQUAL: + return SqlKind.LESS_THAN_OR_EQUAL; + case LESS_THAN: + return SqlKind.GREATER_THAN; + case LESS_THAN_OR_EQUAL: + return SqlKind.GREATER_THAN_OR_EQUAL; + default: + return kind; + } + } + + private static SqlOperator op(SqlKind kind, SqlOperator operator) { + switch (kind) { + case EQUALS: + return SqlStdOperatorTable.EQUALS; + case NOT_EQUALS: + return SqlStdOperatorTable.NOT_EQUALS; + case GREATER_THAN: + return SqlStdOperatorTable.GREATER_THAN; + case GREATER_THAN_OR_EQUAL: + return SqlStdOperatorTable.GREATER_THAN_OR_EQUAL; + case LESS_THAN: + return SqlStdOperatorTable.LESS_THAN; + case LESS_THAN_OR_EQUAL: + return SqlStdOperatorTable.LESS_THAN_OR_EQUAL; + case IS_DISTINCT_FROM: + return SqlStdOperatorTable.IS_DISTINCT_FROM; + case IS_NOT_DISTINCT_FROM: + return SqlStdOperatorTable.IS_NOT_DISTINCT_FROM; + default: + return operator; + } + } + + private static void addJoinKey( + List<RexNode> joinKeyList, + RexNode key, + boolean preserveLastElementInList) { + if (!joinKeyList.isEmpty() && preserveLastElementInList) { + joinKeyList.add(joinKeyList.size() - 1, key); + } else { + joinKeyList.add(key); + } + } + + +} http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java new file mode 100644 index 0000000..911ceda --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java @@ -0,0 +1,198 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.reloperators; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.AbstractRelNode; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.util.Pair; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; +import org.apache.hadoop.hive.ql.optimizer.calcite.TraitsUtil; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * A HiveMultiJoin represents a succession of binary joins. + */ +public final class HiveMultiJoin extends AbstractRelNode { + + private final List<RelNode> inputs; + private final RexNode condition; + private final RelDataType rowType; + private final ImmutableList<Pair<Integer,Integer>> joinInputs; + private final ImmutableList<JoinRelType> joinTypes; + + private final boolean outerJoin; + private final JoinPredicateInfo joinPredInfo; + + + /** + * Constructs a MultiJoin. + * + * @param cluster cluster that join belongs to + * @param inputs inputs into this multi-join + * @param condition join filter applicable to this join node + * @param rowType row type of the join result of this node + * @param joinInputs + * @param joinTypes the join type corresponding to each input; if + * an input is null-generating in a left or right + * outer join, the entry indicates the type of + * outer join; otherwise, the entry is set to + * INNER + */ + public HiveMultiJoin( + RelOptCluster cluster, + List<RelNode> inputs, + RexNode joinFilter, + RelDataType rowType, + List<Pair<Integer,Integer>> joinInputs, + List<JoinRelType> joinTypes) { + super(cluster, TraitsUtil.getDefaultTraitSet(cluster)); + this.inputs = Lists.newArrayList(inputs); + this.condition = joinFilter; + this.rowType = rowType; + + assert joinInputs.size() == joinTypes.size(); + this.joinInputs = ImmutableList.copyOf(joinInputs); + this.joinTypes = ImmutableList.copyOf(joinTypes); + this.outerJoin = containsOuter(); + + this.joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(this); + } + + + @Override + public void replaceInput(int ordinalInParent, RelNode p) { + inputs.set(ordinalInParent, p); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) { + assert traitSet.containsIfApplicable(HiveRelNode.CONVENTION); + return new HiveMultiJoin( + getCluster(), + inputs, + condition, + rowType, + joinInputs, + joinTypes); + } + + public RelWriter explainTerms(RelWriter pw) { + List<String> joinsString = new ArrayList<String>(); + for (int i = 0; i < joinInputs.size(); i++) { + final StringBuilder sb = new StringBuilder(); + sb.append(joinInputs.get(i).left).append(" - ").append(joinInputs.get(i).right) + .append(" : ").append(joinTypes.get(i).name()); + joinsString.add(sb.toString()); + } + + super.explainTerms(pw); + for (Ord<RelNode> ord : Ord.zip(inputs)) { + pw.input("input#" + ord.i, ord.e); + } + return pw.item("condition", condition) + .item("joinsDescription", joinsString); + } + + public RelDataType deriveRowType() { + return rowType; + } + + public List<RelNode> getInputs() { + return inputs; + } + + @Override public List<RexNode> getChildExps() { + return ImmutableList.of(condition); + } + + public RelNode accept(RexShuttle shuttle) { + RexNode joinFilter = shuttle.apply(this.condition); + + if (joinFilter == this.condition) { + return this; + } + + return new HiveMultiJoin( + getCluster(), + inputs, + joinFilter, + rowType, + joinInputs, + joinTypes); + } + + /** + * @return join filters associated with this MultiJoin + */ + public RexNode getCondition() { + return condition; + } + + /** + * @return true if the MultiJoin contains a (partial) outer join. + */ + public boolean isOuterJoin() { + return outerJoin; + } + + /** + * @return join relationships between inputs + */ + public List<Pair<Integer,Integer>> getJoinInputs() { + return joinInputs; + } + + /** + * @return join types of each input + */ + public List<JoinRelType> getJoinTypes() { + return joinTypes; + } + + /** + * @return the join predicate information + */ + public JoinPredicateInfo getJoinPredicateInfo() { + return joinPredInfo; + } + + private boolean containsOuter() { + for (JoinRelType joinType : joinTypes) { + if (joinType != JoinRelType.INNER) { + return true; + } + } + return false; + } +} + +// End MultiJoin.java http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java index 11c3d23..c5ab055 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java @@ -19,6 +19,7 @@ package org.apache.hadoop.hive.ql.optimizer.calcite.rules; import java.util.ArrayList; import java.util.List; +import java.util.Set; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; @@ -28,7 +29,6 @@ import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Exchange; import org.apache.calcite.rel.core.Join; -import org.apache.calcite.rel.rules.MultiJoin; import org.apache.calcite.rex.RexNode; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -37,9 +37,11 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinLeafPredi import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelCollation; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelDistribution; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; /** Not an optimization rule. * Rule to aid in translation from Calcite tree -> Hive tree. @@ -57,7 +59,7 @@ public class HiveInsertExchange4JoinRule extends RelOptRule { /** Rule that creates Exchange operators under a MultiJoin operator. */ public static final HiveInsertExchange4JoinRule EXCHANGE_BELOW_MULTIJOIN = - new HiveInsertExchange4JoinRule(MultiJoin.class); + new HiveInsertExchange4JoinRule(HiveMultiJoin.class); /** Rule that creates Exchange operators under a Join operator. */ public static final HiveInsertExchange4JoinRule EXCHANGE_BELOW_JOIN = @@ -71,8 +73,8 @@ public class HiveInsertExchange4JoinRule extends RelOptRule { @Override public void onMatch(RelOptRuleCall call) { JoinPredicateInfo joinPredInfo; - if (call.rel(0) instanceof MultiJoin) { - MultiJoin multiJoin = call.rel(0); + if (call.rel(0) instanceof HiveMultiJoin) { + HiveMultiJoin multiJoin = call.rel(0); joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(multiJoin); } else if (call.rel(0) instanceof Join) { Join join = call.rel(0); @@ -87,18 +89,23 @@ public class HiveInsertExchange4JoinRule extends RelOptRule { } } - // get key columns from inputs. Those are the columns on which we will distribute on. + // Get key columns from inputs. Those are the columns on which we will distribute on. // It is also the columns we will sort on. List<RelNode> newInputs = new ArrayList<RelNode>(); for (int i=0; i<call.rel(0).getInputs().size(); i++) { List<Integer> joinKeyPositions = new ArrayList<Integer>(); - ImmutableList.Builder<RexNode> keyListBuilder = new ImmutableList.Builder<RexNode>(); + ImmutableList.Builder<RexNode> joinExprsBuilder = new ImmutableList.Builder<RexNode>(); + Set<String> keySet = Sets.newHashSet(); ImmutableList.Builder<RelFieldCollation> collationListBuilder = new ImmutableList.Builder<RelFieldCollation>(); for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) { JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo. getEquiJoinPredicateElements().get(j); - keyListBuilder.add(joinLeafPredInfo.getJoinKeyExprs(i).get(0)); + for (RexNode joinExprNode : joinLeafPredInfo.getJoinExprs(i)) { + if (keySet.add(joinExprNode.toString())) { + joinExprsBuilder.add(joinExprNode); + } + } for (int pos : joinLeafPredInfo.getProjsJoinKeysInChildSchema(i)) { if (!joinKeyPositions.contains(pos)) { joinKeyPositions.add(pos); @@ -109,13 +116,13 @@ public class HiveInsertExchange4JoinRule extends RelOptRule { HiveSortExchange exchange = HiveSortExchange.create(call.rel(0).getInput(i), new HiveRelDistribution(RelDistribution.Type.HASH_DISTRIBUTED, joinKeyPositions), new HiveRelCollation(collationListBuilder.build()), - keyListBuilder.build()); + joinExprsBuilder.build()); newInputs.add(exchange); } RelNode newOp; - if (call.rel(0) instanceof MultiJoin) { - MultiJoin multiJoin = call.rel(0); + if (call.rel(0) instanceof HiveMultiJoin) { + HiveMultiJoin multiJoin = call.rel(0); newOp = multiJoin.copy(multiJoin.getTraitSet(), newInputs); } else if (call.rel(0) instanceof Join) { Join join = call.rel(0); http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinProjectTransposeRule.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinProjectTransposeRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinProjectTransposeRule.java new file mode 100644 index 0000000..40bf043 --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinProjectTransposeRule.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import org.apache.calcite.plan.RelOptRuleOperand; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.rules.JoinProjectTransposeRule; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; + +public class HiveJoinProjectTransposeRule extends JoinProjectTransposeRule { + + public static final HiveJoinProjectTransposeRule BOTH_PROJECT = + new HiveJoinProjectTransposeRule( + operand(HiveJoin.class, + operand(HiveProject.class, any()), + operand(HiveProject.class, any())), + "JoinProjectTransposeRule(Project-Project)", + HiveProject.DEFAULT_PROJECT_FACTORY); + + public static final HiveJoinProjectTransposeRule LEFT_PROJECT = + new HiveJoinProjectTransposeRule( + operand(HiveJoin.class, + some(operand(HiveProject.class, any()))), + "JoinProjectTransposeRule(Project-Other)", + HiveProject.DEFAULT_PROJECT_FACTORY); + + public static final HiveJoinProjectTransposeRule RIGHT_PROJECT = + new HiveJoinProjectTransposeRule( + operand( + HiveJoin.class, + operand(RelNode.class, any()), + operand(HiveProject.class, any())), + "JoinProjectTransposeRule(Other-Project)", + HiveProject.DEFAULT_PROJECT_FACTORY); + + + private HiveJoinProjectTransposeRule( + RelOptRuleOperand operand, + String description, ProjectFactory pFactory) { + super(operand, description, pFactory); + } + +} http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java index 532d7d3..c5e0e11 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java @@ -17,8 +17,8 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite.rules; +import java.util.ArrayList; import java.util.List; -import java.util.Map; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; @@ -26,21 +26,24 @@ import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; -import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.rules.JoinCommuteRule; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; -import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.util.ImmutableBitSet; -import org.apache.calcite.util.ImmutableIntList; import org.apache.calcite.util.Pair; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; /** * Rule that merges a join with multijoin/join children if @@ -49,130 +52,172 @@ import com.google.common.collect.Maps; public class HiveJoinToMultiJoinRule extends RelOptRule { public static final HiveJoinToMultiJoinRule INSTANCE = - new HiveJoinToMultiJoinRule(Join.class); + new HiveJoinToMultiJoinRule(HiveJoin.class, HiveProject.DEFAULT_PROJECT_FACTORY); + + private final ProjectFactory projectFactory; + //~ Constructors ----------------------------------------------------------- /** * Creates a JoinToMultiJoinRule. */ - public HiveJoinToMultiJoinRule(Class<? extends Join> clazz) { - super( - operand(clazz, - operand(RelNode.class, any()), - operand(RelNode.class, any()))); + public HiveJoinToMultiJoinRule(Class<? extends Join> clazz, ProjectFactory projectFactory) { + super(operand(clazz, + operand(RelNode.class, any()), + operand(RelNode.class, any()))); + this.projectFactory = projectFactory; } //~ Methods ---------------------------------------------------------------- @Override public void onMatch(RelOptRuleCall call) { - final Join join = call.rel(0); + final HiveJoin join = call.rel(0); final RelNode left = call.rel(1); final RelNode right = call.rel(2); - final RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + // 1. We try to merge this join with the left child + RelNode multiJoin = mergeJoin(join, left, right); + if (multiJoin != null) { + call.transformTo(multiJoin); + return; + } - // We do not merge outer joins currently - if (join.getJoinType() != JoinRelType.INNER) { + // 2. If we cannot, we swap the inputs so we can try + // to merge it with its right child + RelNode swapped = JoinCommuteRule.swap(join, true); + assert swapped != null; + + // The result of the swapping operation is either + // i) a Project or, + // ii) if the project is trivial, a raw join + final Join newJoin; + Project topProject = null; + if (swapped instanceof Join) { + newJoin = (Join) swapped; + } else { + topProject = (Project) swapped; + newJoin = (Join) swapped.getInput(0); + } + + // 3. We try to merge the join with the right child + multiJoin = mergeJoin(newJoin, right, left); + if (multiJoin != null) { + if (topProject != null) { + multiJoin = projectFactory.createProject(multiJoin, + topProject.getChildExps(), + topProject.getRowType().getFieldNames()); + } + call.transformTo(multiJoin); return; } + } + + // This method tries to merge the join with its left child. The left + // child should be a join for this to happen. + private static RelNode mergeJoin(Join join, RelNode left, RelNode right) { + final RexBuilder rexBuilder = join.getCluster().getRexBuilder(); // We check whether the join can be combined with any of its children final List<RelNode> newInputs = Lists.newArrayList(); final List<RexNode> newJoinFilters = Lists.newArrayList(); newJoinFilters.add(join.getCondition()); - final List<Pair<JoinRelType, RexNode>> joinSpecs = Lists.newArrayList(); - final List<ImmutableBitSet> projFields = Lists.newArrayList(); + final List<Pair<Pair<Integer,Integer>, JoinRelType>> joinSpecs = Lists.newArrayList(); // Left child - if (left instanceof Join || left instanceof MultiJoin) { + if (left instanceof Join || left instanceof HiveMultiJoin) { final RexNode leftCondition; + final List<Pair<Integer,Integer>> leftJoinInputs; + final List<JoinRelType> leftJoinTypes; if (left instanceof Join) { - leftCondition = ((Join) left).getCondition(); + Join hj = (Join) left; + leftCondition = hj.getCondition(); + leftJoinInputs = ImmutableList.of(Pair.of(0, 1)); + leftJoinTypes = ImmutableList.of(hj.getJoinType()); } else { - leftCondition = ((MultiJoin) left).getJoinFilter(); + HiveMultiJoin hmj = (HiveMultiJoin) left; + leftCondition = hmj.getCondition(); + leftJoinInputs = hmj.getJoinInputs(); + leftJoinTypes = hmj.getJoinTypes(); } boolean combinable = isCombinablePredicate(join, join.getCondition(), leftCondition); if (combinable) { newJoinFilters.add(leftCondition); - for (RelNode input : left.getInputs()) { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(input); + for (int i = 0; i < leftJoinInputs.size(); i++) { + joinSpecs.add(Pair.of(leftJoinInputs.get(i), leftJoinTypes.get(i))); } - } else { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(left); + newInputs.addAll(left.getInputs()); + } else { // The join operation in the child is not on the same keys + return null; } - } else { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(left); + } else { // The left child is not a join or multijoin operator + return null; } + final int numberLeftInputs = newInputs.size(); // Right child - if (right instanceof Join || right instanceof MultiJoin) { - final RexNode rightCondition; - if (right instanceof Join) { - rightCondition = shiftRightFilter(join, left, right, - ((Join) right).getCondition()); - } else { - rightCondition = shiftRightFilter(join, left, right, - ((MultiJoin) right).getJoinFilter()); - } - - boolean combinable = isCombinablePredicate(join, join.getCondition(), - rightCondition); - if (combinable) { - newJoinFilters.add(rightCondition); - for (RelNode input : right.getInputs()) { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(input); - } - } else { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(right); - } - } else { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(right); - } + newInputs.add(right); // If we cannot combine any of the children, we bail out if (newJoinFilters.size() == 1) { - return; + return null; + } + + final List<RelDataTypeField> systemFieldList = ImmutableList.of(); + List<List<RexNode>> joinKeyExprs = new ArrayList<List<RexNode>>(); + List<Integer> filterNulls = new ArrayList<Integer>(); + for (int i=0; i<newInputs.size(); i++) { + joinKeyExprs.add(new ArrayList<RexNode>()); + } + RexNode otherCondition = HiveRelOptUtil.splitJoinCondition(systemFieldList, newInputs, join.getCondition(), + joinKeyExprs, filterNulls, null); + // If there are remaining parts in the condition, we bail out + if (!otherCondition.isAlwaysTrue()) { + return null; + } + ImmutableBitSet.Builder keysInInputsBuilder = ImmutableBitSet.builder(); + for (int i=0; i<newInputs.size(); i++) { + List<RexNode> partialCondition = joinKeyExprs.get(i); + if (!partialCondition.isEmpty()) { + keysInInputsBuilder.set(i); + } + } + // If we cannot merge, we bail out + ImmutableBitSet keysInInputs = keysInInputsBuilder.build(); + ImmutableBitSet leftReferencedInputs = + keysInInputs.intersect(ImmutableBitSet.range(numberLeftInputs)); + ImmutableBitSet rightReferencedInputs = + keysInInputs.intersect(ImmutableBitSet.range(numberLeftInputs, newInputs.size())); + if (join.getJoinType() != JoinRelType.INNER && + (leftReferencedInputs.cardinality() > 1 || rightReferencedInputs.cardinality() > 1)) { + return null; + } + // Otherwise, we add to the join specs + if (join.getJoinType() != JoinRelType.INNER) { + int leftInput = keysInInputs.nextSetBit(0); + int rightInput = keysInInputs.nextSetBit(numberLeftInputs); + joinSpecs.add(Pair.of(Pair.of(leftInput, rightInput), join.getJoinType())); + } else { + for (int i : leftReferencedInputs) { + for (int j : rightReferencedInputs) { + joinSpecs.add(Pair.of(Pair.of(i, j), join.getJoinType())); + } + } } + // We can now create a multijoin operator RexNode newCondition = RexUtil.flatten(rexBuilder, RexUtil.composeConjunction(rexBuilder, newJoinFilters, false)); - final ImmutableMap<Integer, ImmutableIntList> newJoinFieldRefCountsMap = - addOnJoinFieldRefCounts(newInputs, - join.getRowType().getFieldCount(), - newCondition); - - List<RexNode> newPostJoinFilters = combinePostJoinFilters(join, left, right); - - RelNode multiJoin = - new MultiJoin( + return new HiveMultiJoin( join.getCluster(), newInputs, newCondition, join.getRowType(), - false, - Pair.right(joinSpecs), Pair.left(joinSpecs), - projFields, - newJoinFieldRefCountsMap, - RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true)); - - call.transformTo(multiJoin); + Pair.right(joinSpecs)); } private static boolean isCombinablePredicate(Join join, @@ -203,7 +248,7 @@ public class HiveJoinToMultiJoinRule extends RelOptRule { * @param rightFilter the filter originating from the right child * @return the adjusted right filter */ - private RexNode shiftRightFilter( + private static RexNode shiftRightFilter( Join joinRel, RelNode left, RelNode right, @@ -228,106 +273,4 @@ public class HiveJoinToMultiJoinRule extends RelOptRule { return rightFilter; } - /** - * Adds on to the existing join condition reference counts the references - * from the new join condition. - * - * @param multiJoinInputs inputs into the new MultiJoin - * @param nTotalFields total number of fields in the MultiJoin - * @param joinCondition the new join condition - * @param origJoinFieldRefCounts existing join condition reference counts - * - * @return Map containing the new join condition - */ - private ImmutableMap<Integer, ImmutableIntList> addOnJoinFieldRefCounts( - List<RelNode> multiJoinInputs, - int nTotalFields, - RexNode joinCondition) { - // count the input references in the join condition - int[] joinCondRefCounts = new int[nTotalFields]; - joinCondition.accept(new InputReferenceCounter(joinCondRefCounts)); - - // add on to the counts for each input into the MultiJoin the - // reference counts computed for the current join condition - final Map<Integer, int[]> refCountsMap = Maps.newHashMap(); - int nInputs = multiJoinInputs.size(); - int currInput = -1; - int startField = 0; - int nFields = 0; - for (int i = 0; i < nTotalFields; i++) { - if (joinCondRefCounts[i] == 0) { - continue; - } - while (i >= (startField + nFields)) { - startField += nFields; - currInput++; - assert currInput < nInputs; - nFields = - multiJoinInputs.get(currInput).getRowType().getFieldCount(); - } - int[] refCounts = refCountsMap.get(currInput); - if (refCounts == null) { - refCounts = new int[nFields]; - refCountsMap.put(currInput, refCounts); - } - refCounts[i - startField] += joinCondRefCounts[i]; - } - - final ImmutableMap.Builder<Integer, ImmutableIntList> builder = - ImmutableMap.builder(); - for (Map.Entry<Integer, int[]> entry : refCountsMap.entrySet()) { - builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue())); - } - return builder.build(); - } - - /** - * Combines the post-join filters from the left and right inputs (if they - * are MultiJoinRels) into a single AND'd filter. - * - * @param joinRel the original LogicalJoin - * @param left left child of the LogicalJoin - * @param right right child of the LogicalJoin - * @return combined post-join filters AND'd together - */ - private List<RexNode> combinePostJoinFilters( - Join joinRel, - RelNode left, - RelNode right) { - final List<RexNode> filters = Lists.newArrayList(); - if (right instanceof MultiJoin) { - final MultiJoin multiRight = (MultiJoin) right; - filters.add( - shiftRightFilter(joinRel, left, multiRight, - multiRight.getPostJoinFilter())); - } - - if (left instanceof MultiJoin) { - filters.add(((MultiJoin) left).getPostJoinFilter()); - } - - return filters; - } - - //~ Inner Classes ---------------------------------------------------------- - - /** - * Visitor that keeps a reference count of the inputs used by an expression. - */ - private class InputReferenceCounter extends RexVisitorImpl<Void> { - private final int[] refCounts; - - public InputReferenceCounter(int[] refCounts) { - super(true); - this.refCounts = refCounts; - } - - public Void visitInputRef(RexInputRef inputRef) { - refCounts[inputRef.getIndex()]++; - return null; - } - } } - -// End JoinToMultiJoinRule.java - http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveProjectMergeRule.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveProjectMergeRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveProjectMergeRule.java index 8b90a15..9199b03 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveProjectMergeRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveProjectMergeRule.java @@ -20,7 +20,6 @@ package org.apache.hadoop.hive.ql.optimizer.calcite.rules; import org.apache.calcite.rel.rules.ProjectMergeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; -//Currently not used, turn this on later public class HiveProjectMergeRule extends ProjectMergeRule { public static final HiveProjectMergeRule INSTANCE = new HiveProjectMergeRule(); http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java index 3d1a309..f72f67f 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java @@ -17,18 +17,32 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite.rules; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; +import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexPermuteInputsShuttle; +import org.apache.calcite.rex.RexVisitor; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql2rel.RelFieldTrimmer; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.Util; +import org.apache.calcite.util.mapping.IntPair; +import org.apache.calcite.util.mapping.Mapping; +import org.apache.calcite.util.mapping.MappingType; +import org.apache.calcite.util.mapping.Mappings; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import com.google.common.collect.ImmutableList; @@ -50,6 +64,98 @@ public class HiveRelFieldTrimmer extends RelFieldTrimmer { semiJoinFactory, sortFactory, aggregateFactory, setOpFactory); } + /** + * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for + * {@link org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin}. + */ + public TrimResult trimFields( + HiveMultiJoin join, + ImmutableBitSet fieldsUsed, + Set<RelDataTypeField> extraFields) { + final int fieldCount = join.getRowType().getFieldCount(); + final RexNode conditionExpr = join.getCondition(); + + // Add in fields used in the condition. + final Set<RelDataTypeField> combinedInputExtraFields = + new LinkedHashSet<RelDataTypeField>(extraFields); + RelOptUtil.InputFinder inputFinder = + new RelOptUtil.InputFinder(combinedInputExtraFields); + inputFinder.inputBitSet.addAll(fieldsUsed); + conditionExpr.accept(inputFinder); + final ImmutableBitSet fieldsUsedPlus = inputFinder.inputBitSet.build(); + + int inputStartPos = 0; + int changeCount = 0; + int newFieldCount = 0; + List<RelNode> newInputs = new ArrayList<RelNode>(); + List<Mapping> inputMappings = new ArrayList<Mapping>(); + for (RelNode input : join.getInputs()) { + final RelDataType inputRowType = input.getRowType(); + final int inputFieldCount = inputRowType.getFieldCount(); + + // Compute required mapping. + ImmutableBitSet.Builder inputFieldsUsed = ImmutableBitSet.builder(); + for (int bit : fieldsUsedPlus) { + if (bit >= inputStartPos && bit < inputStartPos + inputFieldCount) { + inputFieldsUsed.set(bit - inputStartPos); + } + } + + Set<RelDataTypeField> inputExtraFields = + Collections.<RelDataTypeField>emptySet(); + TrimResult trimResult = + trimChild(join, input, inputFieldsUsed.build(), inputExtraFields); + newInputs.add(trimResult.left); + if (trimResult.left != input) { + ++changeCount; + } + + final Mapping inputMapping = trimResult.right; + inputMappings.add(inputMapping); + + // Move offset to point to start of next input. + inputStartPos += inputFieldCount; + newFieldCount += inputMapping.getTargetCount(); + } + + Mapping mapping = + Mappings.create( + MappingType.INVERSE_SURJECTION, + fieldCount, + newFieldCount); + int offset = 0; + int newOffset = 0; + for (int i = 0; i < inputMappings.size(); i++) { + Mapping inputMapping = inputMappings.get(i); + for (IntPair pair : inputMapping) { + mapping.set(pair.source + offset, pair.target + newOffset); + } + offset += inputMapping.getSourceCount(); + newOffset += inputMapping.getTargetCount(); + } + + if (changeCount == 0 + && mapping.isIdentity()) { + return new TrimResult(join, Mappings.createIdentity(fieldCount)); + } + + // Build new join. + final RexVisitor<RexNode> shuttle = new RexPermuteInputsShuttle( + mapping, newInputs.toArray(new RelNode[newInputs.size()])); + RexNode newConditionExpr = conditionExpr.accept(shuttle); + + final RelDataType newRowType = RelOptUtil.permute(join.getCluster().getTypeFactory(), + join.getRowType(), mapping); + final RelNode newJoin = new HiveMultiJoin(join.getCluster(), + newInputs, + newConditionExpr, + newRowType, + join.getJoinInputs(), + join.getJoinTypes()); + + return new TrimResult(newJoin, mapping); + } + protected TrimResult trimChild( RelNode rel, RelNode input, http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java index a75d029..84c6cc8 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java @@ -32,7 +32,7 @@ import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelDistribution.Type; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; @@ -55,11 +55,13 @@ import org.apache.hadoop.hive.ql.exec.Utilities; import org.apache.hadoop.hive.ql.io.AcidUtils.Operation; import org.apache.hadoop.hive.ql.metadata.VirtualColumn; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinLeafPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSemiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSort; @@ -81,6 +83,7 @@ import org.apache.hadoop.hive.ql.parse.WindowingSpec; import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils; +import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; import org.apache.hadoop.hive.ql.plan.FilterDesc; import org.apache.hadoop.hive.ql.plan.JoinCondDesc; import org.apache.hadoop.hive.ql.plan.JoinDesc; @@ -153,8 +156,8 @@ public class HiveOpConverter { return visit((HiveTableScan) rn); } else if (rn instanceof HiveProject) { return visit((HiveProject) rn); - } else if (rn instanceof MultiJoin) { - return visit((MultiJoin) rn); + } else if (rn instanceof HiveMultiJoin) { + return visit((HiveMultiJoin) rn); } else if (rn instanceof HiveJoin) { return visit((HiveJoin) rn); } else if (rn instanceof HiveSemiJoin) { @@ -300,7 +303,7 @@ public class HiveOpConverter { return new OpAttr(inputOpAf.tabAlias, colInfoVColPair.getValue(), selOp); } - OpAttr visit(MultiJoin joinRel) throws SemanticException { + OpAttr visit(HiveMultiJoin joinRel) throws SemanticException { return translateJoin(joinRel); } @@ -326,48 +329,68 @@ public class HiveOpConverter { baseSrc[i] = inputs[i].tabAlias; } + // 2. Generate tags + for (int tag=0; tag<children.size(); tag++) { + ReduceSinkOperator reduceSinkOp = (ReduceSinkOperator) children.get(tag); + reduceSinkOp.getConf().setTag(tag); + } + + // 3. Virtual columns + Set<Integer> newVcolsInCalcite = new HashSet<Integer>(); + newVcolsInCalcite.addAll(inputs[0].vcolsInCalcite); + if (joinRel instanceof HiveMultiJoin || + extractJoinType((HiveJoin)joinRel) != JoinType.LEFTSEMI) { + int shift = inputs[0].inputs.get(0).getSchema().getSignature().size(); + for (int i = 1; i < inputs.length; i++) { + newVcolsInCalcite.addAll(HiveCalciteUtil.shiftVColsSet(inputs[i].vcolsInCalcite, shift)); + shift += inputs[i].inputs.get(0).getSchema().getSignature().size(); + } + } + if (LOG.isDebugEnabled()) { LOG.debug("Translating operator rel#" + joinRel.getId() + ":" + joinRel.getRelTypeName() + " with row type: [" + joinRel.getRowType() + "]"); } - // 2. Convert join condition JoinPredicateInfo joinPredInfo; if (joinRel instanceof HiveJoin) { - joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo((HiveJoin)joinRel); + joinPredInfo = ((HiveJoin)joinRel).getJoinPredicateInfo(); } else { - joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo((MultiJoin)joinRel); + joinPredInfo = ((HiveMultiJoin)joinRel).getJoinPredicateInfo(); } - // 3. Extract join key expressions from HiveSortExchange + // 4. Extract join key expressions from HiveSortExchange ExprNodeDesc[][] joinExpressions = new ExprNodeDesc[inputs.length][]; for (int i = 0; i < inputs.length; i++) { joinExpressions[i] = ((HiveSortExchange) joinRel.getInput(i)).getJoinExpressions(); } - // 4.a Generate tags - for (int tag=0; tag<children.size(); tag++) { - ReduceSinkOperator reduceSinkOp = (ReduceSinkOperator) children.get(tag); - reduceSinkOp.getConf().setTag(tag); - } - // 4.b Generate Join operator - JoinOperator joinOp = genJoin(joinRel, joinPredInfo, children, joinExpressions, baseSrc, tabAlias); - - // 5. TODO: Extract condition for non-equi join elements (if any) and - // add it - - // 6. Virtual columns - Set<Integer> newVcolsInCalcite = new HashSet<Integer>(); - newVcolsInCalcite.addAll(inputs[0].vcolsInCalcite); - if (joinRel instanceof MultiJoin || - extractJoinType((HiveJoin)joinRel) != JoinType.LEFTSEMI) { - int shift = inputs[0].inputs.get(0).getSchema().getSignature().size(); - for (int i = 1; i < inputs.length; i++) { - newVcolsInCalcite.addAll(HiveCalciteUtil.shiftVColsSet(inputs[i].vcolsInCalcite, shift)); - shift += inputs[i].inputs.get(0).getSchema().getSignature().size(); + // 5. Extract rest of join predicate info. We infer the rest of join condition + // that will be added to the filters (join conditions that are not part of + // the join key) + ExprNodeDesc[][] filterExpressions = new ExprNodeDesc[inputs.length][]; + for (int i = 0; i< inputs.length; i++) { + List<ExprNodeDesc> filterExpressionsForInput = new ArrayList<ExprNodeDesc>(); + Set<String> keySet = new HashSet<String>(); + for (int j = 0; j < joinPredInfo.getNonEquiJoinPredicateElements().size(); j++) { + JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo. + getNonEquiJoinPredicateElements().get(j); + for (RexNode joinExprNode : joinLeafPredInfo.getJoinExprs(i)) { + if (keySet.add(joinExprNode.toString())) { + ExprNodeDesc expr = convertToExprNode(joinExprNode, joinRel, + null, newVcolsInCalcite); + filterExpressionsForInput.add(expr); + } + } } + filterExpressions[i] = filterExpressionsForInput.toArray( + new ExprNodeDesc[filterExpressionsForInput.size()]); } + // 6. Generate Join operator + JoinOperator joinOp = genJoin(joinRel, joinExpressions, filterExpressions, children, + baseSrc, tabAlias); + // 7. Return result return new OpAttr(tabAlias, newVcolsInCalcite, joinOp); } @@ -798,20 +821,32 @@ public class HiveOpConverter { return rsOp; } - private static JoinOperator genJoin(RelNode join, JoinPredicateInfo joinPredInfo, - List<Operator<?>> children, ExprNodeDesc[][] joinKeys, String[] baseSrc, String tabAlias) throws SemanticException { - - // Extract join type - JoinType joinType; - if (join instanceof MultiJoin) { - joinType = JoinType.INNER; + private static JoinOperator genJoin(RelNode join, ExprNodeDesc[][] joinExpressions, + ExprNodeDesc[][] filterExpressions, List<Operator<?>> children, + String[] baseSrc, String tabAlias) throws SemanticException { + + // 1. Extract join type + JoinCondDesc[] joinCondns; + boolean semiJoin; + boolean noOuterJoin; + if (join instanceof HiveMultiJoin) { + HiveMultiJoin hmj = (HiveMultiJoin) join; + joinCondns = new JoinCondDesc[hmj.getJoinInputs().size()]; + for (int i = 0; i < hmj.getJoinInputs().size(); i++) { + joinCondns[i] = new JoinCondDesc(new JoinCond( + hmj.getJoinInputs().get(i).left, + hmj.getJoinInputs().get(i).right, + transformJoinType(hmj.getJoinTypes().get(i)))); + } + semiJoin = false; + noOuterJoin = !hmj.isOuterJoin(); } else { - joinType = extractJoinType((HiveJoin)join); - } - - JoinCondDesc[] joinCondns = new JoinCondDesc[children.size()-1]; - for (int i=1; i<children.size(); i++) { - joinCondns[i-1] = new JoinCondDesc(new JoinCond(0, i, joinType)); + joinCondns = new JoinCondDesc[1]; + JoinType joinType = extractJoinType((HiveJoin)join); + joinCondns[0] = new JoinCondDesc(new JoinCond(0, 1, joinType)); + semiJoin = joinType == JoinType.LEFTSEMI; + noOuterJoin = joinType != JoinType.FULLOUTER && joinType != JoinType.LEFTOUTER + && joinType != JoinType.RIGHTOUTER; } ArrayList<ColumnInfo> outputColumns = new ArrayList<ColumnInfo>(); @@ -820,12 +855,14 @@ public class HiveOpConverter { Operator<?>[] childOps = new Operator[children.size()]; Map<String, Byte> reversedExprs = new HashMap<String, Byte>(); - HashMap<Byte, List<ExprNodeDesc>> exprMap = new HashMap<Byte, List<ExprNodeDesc>>(); + Map<Byte, List<ExprNodeDesc>> exprMap = new HashMap<Byte, List<ExprNodeDesc>>(); + Map<Byte, List<ExprNodeDesc>> filters = new HashMap<Byte, List<ExprNodeDesc>>(); Map<String, ExprNodeDesc> colExprMap = new HashMap<String, ExprNodeDesc>(); HashMap<Integer, Set<String>> posToAliasMap = new HashMap<Integer, Set<String>>(); int outputPos = 0; for (int pos = 0; pos < children.size(); pos++) { + // 2. Backtracking from RS ReduceSinkOperator inputRS = (ReduceSinkOperator) children.get(pos); if (inputRS.getNumParent() != 1) { throw new SemanticException("RS should have single parent"); @@ -837,8 +874,8 @@ public class HiveOpConverter { Byte tag = (byte) rsDesc.getTag(); - // Semijoin - if (joinType == JoinType.LEFTSEMI && pos != 0) { + // 2.1. If semijoin... + if (semiJoin && pos != 0) { exprMap.put(tag, new ArrayList<ExprNodeDesc>()); childOps[pos] = inputRS; continue; @@ -865,20 +902,44 @@ public class HiveOpConverter { exprMap.put(tag, new ArrayList<ExprNodeDesc>(descriptors.values())); colExprMap.putAll(descriptors); childOps[pos] = inputRS; + + // 3. We populate the filters structure + List<ExprNodeDesc> filtersForInput = new ArrayList<ExprNodeDesc>(); + for (ExprNodeDesc expr : filterExpressions[pos]) { + if (expr instanceof ExprNodeGenericFuncDesc) { + ExprNodeGenericFuncDesc func = (ExprNodeGenericFuncDesc) expr; + List<ExprNodeDesc> newChildren = new ArrayList<ExprNodeDesc>(); + for (ExprNodeDesc functionChild : func.getChildren()) { + if (functionChild instanceof ExprNodeColumnDesc) { + newChildren.add(colExprMap.get(functionChild.getExprString())); + } else { + newChildren.add(functionChild); + } + } + func.setChildren(newChildren); + filtersForInput.add(expr); + } + else { + filtersForInput.add(expr); + } + } + filters.put(tag, filtersForInput); } - boolean noOuterJoin = joinType != JoinType.FULLOUTER && joinType != JoinType.LEFTOUTER - && joinType != JoinType.RIGHTOUTER; - JoinDesc desc = new JoinDesc(exprMap, outputColumnNames, noOuterJoin, joinCondns, joinKeys); + JoinDesc desc = new JoinDesc(exprMap, outputColumnNames, noOuterJoin, joinCondns, + filters, joinExpressions); desc.setReversedExprs(reversedExprs); + // 4. Create and populate filter map + int[][] filterMap = new int[joinExpressions.length][]; + + desc.setFilterMap(filterMap); + JoinOperator joinOp = (JoinOperator) OperatorFactory.getAndMakeChild(desc, new RowSchema( outputColumns), childOps); joinOp.setColumnExprMap(colExprMap); joinOp.setPosToAliasMap(posToAliasMap); - // TODO: null safes? - if (LOG.isDebugEnabled()) { LOG.debug("Generated " + joinOp + " with row schema: [" + joinOp.getSchema() + "]"); } @@ -914,6 +975,25 @@ public class HiveOpConverter { return resultJoinType; } + private static JoinType transformJoinType(JoinRelType type) { + JoinType resultJoinType; + switch (type) { + case FULL: + resultJoinType = JoinType.FULLOUTER; + break; + case LEFT: + resultJoinType = JoinType.LEFTOUTER; + break; + case RIGHT: + resultJoinType = JoinType.RIGHTOUTER; + break; + default: + resultJoinType = JoinType.INNER; + break; + } + return resultJoinType; + } + private static Map<String, ExprNodeDesc> buildBacktrackFromReduceSinkForJoin(int initialPos, List<String> outputColumnNames, List<String> keyColNames, List<String> valueColNames, int[] index, Operator<?> inputOp, String tabAlias) { @@ -957,7 +1037,12 @@ public class HiveOpConverter { } private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, String tabAlias, OpAttr inputAttr) { - return rn.accept(new ExprNodeConverter(tabAlias, inputRel.getRowType(), inputAttr.vcolsInCalcite, + return convertToExprNode(rn, inputRel, tabAlias, inputAttr.vcolsInCalcite); + } + + private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, String tabAlias, + Set<Integer> vcolsInCalcite) { + return rn.accept(new ExprNodeConverter(tabAlias, inputRel.getRowType(), vcolsInCalcite, inputRel.getCluster().getTypeFactory())); } http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java index e821b1d..a73e24e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java @@ -64,7 +64,6 @@ import org.apache.calcite.rel.metadata.CachingRelMetadataProvider; import org.apache.calcite.rel.metadata.ChainedRelMetadataProvider; import org.apache.calcite.rel.metadata.RelMetadataProvider; import org.apache.calcite.rel.rules.FilterAggregateTransposeRule; -import org.apache.calcite.rel.rules.FilterMergeRule; import org.apache.calcite.rel.rules.FilterProjectTransposeRule; import org.apache.calcite.rel.rules.JoinPushTransitivePredicatesRule; import org.apache.calcite.rel.rules.JoinToMultiJoinRule; @@ -140,9 +139,11 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveFilterProjectTransp import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveFilterSetOpTransposeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveInsertExchange4JoinRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveJoinAddNotNullRule; +import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveJoinProjectTransposeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveJoinToMultiJoinRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HivePartitionPruneRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HivePreFilteringRule; +import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveProjectMergeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveRelFieldTrimmer; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveWindowingFixRule; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.ASTConverter; @@ -860,7 +861,7 @@ public class CalcitePlanner extends SemanticAnalyzer { calciteOptimizedPlan = hepPlanner.findBestExp(); - // run rule to fix windowing issue when it is done over + // 4. Run rule to fix windowing issue when it is done over // aggregation columns (HIVE-10627) hepPgmBldr = new HepProgramBuilder().addMatchOrder(HepMatchOrder.BOTTOM_UP); hepPgmBldr.addRuleInstance(HiveWindowingFixRule.INSTANCE); @@ -870,16 +871,29 @@ public class CalcitePlanner extends SemanticAnalyzer { hepPlanner.setRoot(calciteOptimizedPlan); calciteOptimizedPlan = hepPlanner.findBestExp(); + // 5. Run rules to aid in translation from Calcite tree to Hive tree if (HiveConf.getBoolVar(conf, ConfVars.HIVE_CBO_RETPATH_HIVEOP)) { - // run rules to aid in translation from Optiq tree -> Hive tree + // 5.1. Merge join into multijoin operators (if possible) hepPgmBldr = new HepProgramBuilder().addMatchOrder(HepMatchOrder.BOTTOM_UP); hepPgmBldr.addRuleInstance(HiveJoinToMultiJoinRule.INSTANCE); + hepPgmBldr = hepPgmBldr.addRuleCollection(ImmutableList.of( + HiveJoinProjectTransposeRule.BOTH_PROJECT, + HiveJoinToMultiJoinRule.INSTANCE, + HiveProjectMergeRule.INSTANCE)); hepPlanner = new HepPlanner(hepPgmBldr.build()); hepPlanner.registerMetadataProviders(list); cluster.setMetadataProvider(new CachingRelMetadataProvider(chainedProvider, hepPlanner)); hepPlanner.setRoot(calciteOptimizedPlan); calciteOptimizedPlan = hepPlanner.findBestExp(); - + // The previous rules can pull up projections through join operators, + // thus we run the field trimmer again to push them back down + HiveRelFieldTrimmer fieldTrimmer = new HiveRelFieldTrimmer(null, HiveProject.DEFAULT_PROJECT_FACTORY, + HiveFilter.DEFAULT_FILTER_FACTORY, HiveJoin.HIVE_JOIN_FACTORY, + HiveSemiJoin.HIVE_SEMIJOIN_FACTORY, HiveSort.HIVE_SORT_REL_FACTORY, + HiveAggregate.HIVE_AGGR_REL_FACTORY, HiveUnion.UNION_REL_FACTORY); + calciteOptimizedPlan = fieldTrimmer.trim(calciteOptimizedPlan); + + // 5.2. Introduce exchange operators below join/multijoin operators hepPgmBldr = new HepProgramBuilder().addMatchOrder(HepMatchOrder.BOTTOM_UP); hepPgmBldr.addRuleInstance(HiveInsertExchange4JoinRule.EXCHANGE_BELOW_JOIN); hepPgmBldr.addRuleInstance(HiveInsertExchange4JoinRule.EXCHANGE_BELOW_MULTIJOIN); http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/java/org/apache/hadoop/hive/ql/plan/JoinDesc.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/plan/JoinDesc.java b/ql/src/java/org/apache/hadoop/hive/ql/plan/JoinDesc.java index 37012b4..3a4ea2f 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/plan/JoinDesc.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/plan/JoinDesc.java @@ -109,13 +109,6 @@ public class JoinDesc extends AbstractOperatorDesc { } public JoinDesc(final Map<Byte, List<ExprNodeDesc>> exprs, - List<String> outputColumnNames, final boolean noOuterJoin, - final JoinCondDesc[] conds, ExprNodeDesc[][] joinKeys) { - this (exprs, outputColumnNames, noOuterJoin, conds, - new HashMap<Byte, List<ExprNodeDesc>>(), joinKeys); - } - - public JoinDesc(final Map<Byte, List<ExprNodeDesc>> exprs, List<String> outputColumnNames, final boolean noOuterJoin, final JoinCondDesc[] conds, final Map<Byte, List<ExprNodeDesc>> filters, ExprNodeDesc[][] joinKeys) { http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/test/queries/clientpositive/cbo_rp_join0.q ---------------------------------------------------------------------- diff --git a/ql/src/test/queries/clientpositive/cbo_rp_join0.q b/ql/src/test/queries/clientpositive/cbo_rp_join0.q new file mode 100644 index 0000000..acfff75 --- /dev/null +++ b/ql/src/test/queries/clientpositive/cbo_rp_join0.q @@ -0,0 +1,26 @@ +set hive.cbo.enable=true; +set hive.exec.check.crossproducts=false; + +set hive.stats.fetch.column.stats=true; +set hive.auto.convert.join=false; + +-- SORT_QUERY_RESULTS +-- Merge join into multijoin operator 1 +explain select key, cbo_t1.c_int, cbo_t2.p, q from cbo_t1 join +(select cbo_t2.key as p, cbo_t2.c_int as q, c_float as r from cbo_t2) cbo_t2 on cbo_t1.key=p right outer join +(select key as a, c_int as b, cbo_t3.c_float as c from cbo_t3) cbo_t3 on cbo_t1.key=a; + +select key, cbo_t1.c_int, cbo_t2.p, q from cbo_t1 join +(select cbo_t2.key as p, cbo_t2.c_int as q, c_float as r from cbo_t2) cbo_t2 on cbo_t1.key=p right outer join +(select key as a, c_int as b, cbo_t3.c_float as c from cbo_t3) cbo_t3 on cbo_t1.key=a; + +-- Merge join into multijoin operator 2 +explain select key, c_int, cbo_t2.p, cbo_t2.q, cbo_t3.x, cbo_t4.b from cbo_t1 join +(select cbo_t2.key as p, cbo_t2.c_int as q, c_float as r from cbo_t2) cbo_t2 on cbo_t1.key=p right outer join +(select cbo_t3.key as x, cbo_t3.c_int as y, c_float as z from cbo_t3) cbo_t3 on cbo_t1.key=x left outer join +(select key as a, c_int as b, c_float as c from cbo_t1) cbo_t4 on cbo_t1.key=a; + +select key, c_int, cbo_t2.p, cbo_t2.q, cbo_t3.x, cbo_t4.b from cbo_t1 join +(select cbo_t2.key as p, cbo_t2.c_int as q, c_float as r from cbo_t2) cbo_t2 on cbo_t1.key=p right outer join +(select cbo_t3.key as x, cbo_t3.c_int as y, c_float as z from cbo_t3) cbo_t3 on cbo_t1.key=x left outer join +(select key as a, c_int as b, c_float as c from cbo_t1) cbo_t4 on cbo_t1.key=a; http://git-wip-us.apache.org/repos/asf/hive/blob/6d19df3a/ql/src/test/queries/clientpositive/cbo_rp_join1.q ---------------------------------------------------------------------- diff --git a/ql/src/test/queries/clientpositive/cbo_rp_join1.q b/ql/src/test/queries/clientpositive/cbo_rp_join1.q new file mode 100644 index 0000000..ce6abe4 --- /dev/null +++ b/ql/src/test/queries/clientpositive/cbo_rp_join1.q @@ -0,0 +1,22 @@ +set hive.auto.convert.join = true; + +CREATE TABLE myinput1(key int, value int); +LOAD DATA LOCAL INPATH '../../data/files/in3.txt' INTO TABLE myinput1; + +SET hive.optimize.bucketmapjoin = true; +SET hive.optimize.bucketmapjoin.sortedmerge = true; +SET hive.input.format = org.apache.hadoop.hive.ql.io.BucketizedHiveInputFormat; + +SET hive.outerjoin.supports.filters = false; + +EXPLAIN SELECT sum(hash(a.key,a.value,b.key,b.value)) FROM myinput1 a FULL OUTER JOIN myinput1 b on a.key = 40 AND b.key = 40; +SELECT sum(hash(a.key,a.value,b.key,b.value)) FROM myinput1 a FULL OUTER JOIN myinput1 b on a.key = 40 AND b.key = 40; + +EXPLAIN SELECT sum(hash(a.key,a.value,b.key,b.value)) FROM myinput1 a FULL OUTER JOIN myinput1 b on a.key = 40 AND a.value = 40 AND a.key = a.value AND b.key = 40; +SELECT sum(hash(a.key,a.value,b.key,b.value)) FROM myinput1 a FULL OUTER JOIN myinput1 b on a.key = 40 AND a.key = a.value AND b.key = 40; + +EXPLAIN SELECT sum(hash(a.key,a.value,b.key,b.value)) FROM myinput1 a FULL OUTER JOIN myinput1 b on a.key = 40 AND a.key = b.key AND b.key = 40; +SELECT sum(hash(a.key,a.value,b.key,b.value)) FROM myinput1 a FULL OUTER JOIN myinput1 b on a.key = 40 AND a.key = b.key AND b.key = 40; + +EXPLAIN SELECT sum(hash(a.key,a.value,b.key,b.value)) FROM myinput1 a FULL OUTER JOIN myinput1 b on a.key > 40 AND a.value > 50 AND a.key = a.value AND b.key > 40 AND b.value > 50 AND b.key = b.value; +SELECT sum(hash(a.key,a.value,b.key,b.value)) FROM myinput1 a FULL OUTER JOIN myinput1 b on a.key > 40 AND a.value > 50 AND a.key = a.value AND b.key > 40 AND b.value > 50 AND b.key = b.value;
