singhpk234 commented on a change in pull request #35789: URL: https://github.com/apache/spark/pull/35789#discussion_r829247097
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala ########## @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete} +import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Insert a filter on one side of the join if the other side has a selective predicate. + * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something + * else in the future. 
+ */ +object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { + + // Wraps `expr` with a hash function if its byte size is larger than an integer. + private def mayWrapWithHash(expr: Expression): Expression = { + if (expr.dataType.defaultSize > IntegerType.defaultSize) { + new Murmur3Hash(Seq(expr)) + } else { + expr + } + } + + private def injectFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled) + if (conf.runtimeFilterBloomFilterEnabled) { + injectBloomFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } else { + injectInSubqueryFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } Review comment: [question] Consider a scenario where both confs (Bloom filter (BF) and semi-join reduction) are enabled: 1. When we call injectFilter for the left child, we try BF. If BF cannot be applied but a semi-join could be, the semi-join is never tried, so we leave the left side unchanged even though it was eligible for semi-join reduction. 2. Suppose we were not able to inject a BF / semi-join on the left, so we turn to the right child and try to apply BF, which again cannot be applied although a semi-join could be; we would miss applying the semi-join here as well. As per my understanding, we miss both cases 1 and 2 here. Am I missing something? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala ########## @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete} +import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Insert a filter on one side of the join if the other side has a selective predicate. + * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something + * else in the future. + */ +object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { + + // Wraps `expr` with a hash function if its byte size is larger than an integer. 
+ private def mayWrapWithHash(expr: Expression): Expression = { + if (expr.dataType.defaultSize > IntegerType.defaultSize) { + new Murmur3Hash(Seq(expr)) + } else { + expr + } + } + + private def injectFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled) + if (conf.runtimeFilterBloomFilterEnabled) { + injectBloomFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } else { + injectInSubqueryFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } + } + + private def injectBloomFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + // Skip if the filter creation side is too big + if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) { + return filterApplicationSidePlan + } + val rowCount = filterCreationSidePlan.stats.rowCount + val bloomFilterAgg = + if (rowCount.isDefined && rowCount.get.longValue > 0L) { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), + Literal(rowCount.get.longValue)) + } else { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) + } + val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) + val alias = Alias(aggExp, "bloomFilter")() + val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan)) + val bloomFilterSubquery = ScalarSubquery(aggregate, Nil) + val filter = BloomFilterMightContain(bloomFilterSubquery, + new XxHash64(Seq(filterApplicationSideExp))) + Filter(filter, filterApplicationSidePlan) + } + + private def 
injectInSubqueryFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) + val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) + val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() + val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) + if (!canBroadcastBySize(aggregate, conf)) { + // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, + // i.e., the semi-join will be a shuffled join, which is not worthwhile. + return filterApplicationSidePlan + } + val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)), + ListQuery(aggregate, childOutputs = aggregate.output)) + Filter(filter, filterApplicationSidePlan) + } + + /** + * Returns whether the plan is a simple filter over scan and the filter is likely selective + * Also check if the plan only has simple expressions (attribute reference, literals) so that we + * do not add a subquery that might have an expensive computation + */ + private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { + val ret = plan match { + case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => + filters.forall(isSimpleExpression) && + filters.exists(isLikelySelective) + case _ => false + } + !plan.isStreaming && ret + } + + private def isSimpleExpression(e: Expression): Boolean = { + !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, + REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE) + } + + private def canFilterLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + case _ => false + } + + private def canFilterRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter => true + case _ => false + } + + private def 
isProbablyShuffleJoin(left: LogicalPlan, + right: LogicalPlan, hint: JoinHint): Boolean = { + !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) + } + + private def probablyHasShuffle(plan: LogicalPlan): Boolean = { + plan.collectFirst { + case j@Join(left, right, _, _, hint) + if isProbablyShuffleJoin(left, right, hint) => j + case a: Aggregate => a + }.nonEmpty + } + + // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`. + private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = { + val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES) + filterApplicationSide.collect({ + case leaf: LeafNode => leaf + }).map(scan => { + // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid + // creating a Bloom filter when the filter application side is very small, so using 0 + // as the byte size when the actual size is unknown can avoid regression by applying BF + // on a small table. + if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes + }).max + } + + // Returns true if `filterApplicationSide` satisfies the byte size requirement to apply a + // Bloom filter; false otherwise. + private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = { + // In case `filterApplicationSide` is a union of many small tables, disseminating the Bloom + // filter to each small task might be more costly than scanning them itself. Thus, we use max + // rather than sum here. 
+ val maxScanSize = maxScanByteSize(filterApplicationSide) + maxScanSize >= + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD) + } + + /** + * Check that: + * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + * - expression references originate from a single leaf node) + * - The filter creation side has a selective predicate + * - The current join is a shuffle join or a broadcast join that has a shuffle below it + * - The max filterApplicationSide scan size is greater than a configurable threshold + */ + private def filteringHasBenefit( + filterApplicationSide: LogicalPlan, + filterCreationSide: LogicalPlan, + filterApplicationSideExp: Expression, + hint: JoinHint): Boolean = { + // Check that: + // 1. The filterApplicationSideJoinExp can be pushed down through joins and aggregates (i.e the + // expression references originate from a single leaf node) + // 2. The filter creation side has a selective predicate + // 3. The current join is a shuffle join or a broadcast join that has a shuffle or aggregate + // in the filter application side + // 4. 
The max filterApplicationSide scan size is greater than a configurable threshold + findExpressionAndTrackLineageDown(filterApplicationSideExp, + filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) && + (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) || + probablyHasShuffle(filterApplicationSide)) && + satisfyByteSizeRequirement(filterApplicationSide) + } + + def hasRuntimeFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + if (conf.runtimeFilterBloomFilterEnabled) { + hasBloomFilter(left, right, leftKey, rightKey) + } else { + hasInSubquery(left, right, leftKey, rightKey) + } Review comment: [question] What would the behaviour be when both confs are enabled? I believe we could restructure this to something like: ```suggestion var isRuntimeFilterPresent = false if (conf.runtimeFilterBloomFilterEnabled) { isRuntimeFilterPresent = isRuntimeFilterPresent || hasBloomFilter(left, right, leftKey, rightKey) } if (conf.runtimeFilterSemiJoinReductionEnabled) { isRuntimeFilterPresent = isRuntimeFilterPresent || hasInSubquery(left, right, leftKey, rightKey) } isRuntimeFilterPresent ``` WDYT? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala ########## @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete} +import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Insert a filter on one side of the join if the other side has a selective predicate. + * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something + * else in the future. + */ +object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { + + // Wraps `expr` with a hash function if its byte size is larger than an integer. 
+ private def mayWrapWithHash(expr: Expression): Expression = { + if (expr.dataType.defaultSize > IntegerType.defaultSize) { + new Murmur3Hash(Seq(expr)) + } else { + expr + } + } + + private def injectFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled) + if (conf.runtimeFilterBloomFilterEnabled) { + injectBloomFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } else { + injectInSubqueryFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } + } + + private def injectBloomFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + // Skip if the filter creation side is too big + if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) { + return filterApplicationSidePlan + } + val rowCount = filterCreationSidePlan.stats.rowCount + val bloomFilterAgg = + if (rowCount.isDefined && rowCount.get.longValue > 0L) { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), + Literal(rowCount.get.longValue)) + } else { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) + } + val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) + val alias = Alias(aggExp, "bloomFilter")() + val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan)) + val bloomFilterSubquery = ScalarSubquery(aggregate, Nil) + val filter = BloomFilterMightContain(bloomFilterSubquery, + new XxHash64(Seq(filterApplicationSideExp))) + Filter(filter, filterApplicationSidePlan) + } + + private def 
injectInSubqueryFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) + val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) + val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() + val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) + if (!canBroadcastBySize(aggregate, conf)) { + // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, + // i.e., the semi-join will be a shuffled join, which is not worthwhile. + return filterApplicationSidePlan + } + val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)), + ListQuery(aggregate, childOutputs = aggregate.output)) + Filter(filter, filterApplicationSidePlan) + } + + /** + * Returns whether the plan is a simple filter over scan and the filter is likely selective + * Also check if the plan only has simple expressions (attribute reference, literals) so that we + * do not add a subquery that might have an expensive computation + */ + private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { + val ret = plan match { + case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => + filters.forall(isSimpleExpression) && + filters.exists(isLikelySelective) + case _ => false + } + !plan.isStreaming && ret + } + + private def isSimpleExpression(e: Expression): Boolean = { + !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, + REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE) + } + + private def canFilterLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + case _ => false + } + + private def canFilterRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter => true + case _ => false + } + + private def 
isProbablyShuffleJoin(left: LogicalPlan, + right: LogicalPlan, hint: JoinHint): Boolean = { + !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) + } + + private def probablyHasShuffle(plan: LogicalPlan): Boolean = { + plan.collectFirst { + case j@Join(left, right, _, _, hint) + if isProbablyShuffleJoin(left, right, hint) => j + case a: Aggregate => a + }.nonEmpty + } + + // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`. + private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = { + val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES) + filterApplicationSide.collect({ + case leaf: LeafNode => leaf + }).map(scan => { + // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid + // creating a Bloom filter when the filter application side is very small, so using 0 + // as the byte size when the actual size is unknown can avoid regression by applying BF + // on a small table. + if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes + }).max + } + + // Returns true if `filterApplicationSide` satisfies the byte size requirement to apply a + // Bloom filter; false otherwise. + private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = { + // In case `filterApplicationSide` is a union of many small tables, disseminating the Bloom + // filter to each small task might be more costly than scanning them itself. Thus, we use max + // rather than sum here. 
+ val maxScanSize = maxScanByteSize(filterApplicationSide) + maxScanSize >= + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD) + } + + /** + * Check that: + * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + * - expression references originate from a single leaf node) + * - The filter creation side has a selective predicate + * - The current join is a shuffle join or a broadcast join that has a shuffle below it + * - The max filterApplicationSide scan size is greater than a configurable threshold + */ + private def filteringHasBenefit( + filterApplicationSide: LogicalPlan, + filterCreationSide: LogicalPlan, + filterApplicationSideExp: Expression, + hint: JoinHint): Boolean = { + // Check that: + // 1. The filterApplicationSideJoinExp can be pushed down through joins and aggregates (i.e the + // expression references originate from a single leaf node) + // 2. The filter creation side has a selective predicate + // 3. The current join is a shuffle join or a broadcast join that has a shuffle or aggregate + // in the filter application side + // 4. The max filterApplicationSide scan size is greater than a configurable threshold + findExpressionAndTrackLineageDown(filterApplicationSideExp, + filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) && + (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) || + probablyHasShuffle(filterApplicationSide)) && + satisfyByteSizeRequirement(filterApplicationSide) + } + + def hasRuntimeFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + if (conf.runtimeFilterBloomFilterEnabled) { + hasBloomFilter(left, right, leftKey, rightKey) + } else { + hasInSubquery(left, right, leftKey, rightKey) + } Review comment: [question] What would the behaviour be when both confs are enabled? I believe we should then check for the presence of both runtime filters. WDYT?
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala ########## @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete} +import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Insert a filter on one side of the join if the other side has a selective predicate. + * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something + * else in the future. 
+ */ +object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { + + // Wraps `expr` with a hash function if its byte size is larger than an integer. + private def mayWrapWithHash(expr: Expression): Expression = { + if (expr.dataType.defaultSize > IntegerType.defaultSize) { + new Murmur3Hash(Seq(expr)) + } else { + expr + } + } + + private def injectFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled) + if (conf.runtimeFilterBloomFilterEnabled) { + injectBloomFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } else { + injectInSubqueryFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } + } + + private def injectBloomFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + // Skip if the filter creation side is too big + if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterBloomFilterThreshold) { + return filterApplicationSidePlan + } + val rowCount = filterCreationSidePlan.stats.rowCount + val bloomFilterAgg = + if (rowCount.isDefined && rowCount.get.longValue > 0L) { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), + Literal(rowCount.get.longValue)) + } else { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) + } + val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) + val alias = Alias(aggExp, "bloomFilter")() + val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan)) + val bloomFilterSubquery = ScalarSubquery(aggregate, Nil) 
+ val filter = BloomFilterMightContain(bloomFilterSubquery, + new XxHash64(Seq(filterApplicationSideExp))) + Filter(filter, filterApplicationSidePlan) + } + + private def injectInSubqueryFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) + val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) + val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() + val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) + if (!canBroadcastBySize(aggregate, conf)) { + // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, + // i.e., the semi-join will be a shuffled join, which is not worthwhile. + return filterApplicationSidePlan + } + val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)), + ListQuery(aggregate, childOutputs = aggregate.output)) + Filter(filter, filterApplicationSidePlan) + } + + /** + * Returns whether the plan is a simple filter over scan and the filter is likely selective + * Also check if the plan only has simple expressions (attribute reference, literals) so that we + * do not add a subquery that might have an expensive computation + */ + private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { + plan.expressions + val ret = plan match { + case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => + filters.forall(isSimpleExpression) && + filters.exists(isLikelySelective) + case _ => false + } + !plan.isStreaming && ret + } + + private def isSimpleExpression(e: Expression): Boolean = { + !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, + REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE) + } + + private def canFilterLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + 
case _ => false + } + + private def canFilterRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter => true + case _ => false + } + + private def isProbablyShuffleJoin(left: LogicalPlan, + right: LogicalPlan, hint: JoinHint): Boolean = { + !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) + } + + private def probablyHasShuffle(plan: LogicalPlan): Boolean = { + plan.collect { + case j@Join(left, right, _, _, hint) + if !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) => j + case a: Aggregate => a + }.nonEmpty + } + + // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`. + private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = { + val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES) + filterApplicationSide.collect({ + case leaf: LeafNode => leaf + }).map(scan => { + // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid + // creating a Bloom filter when the filter application side is very small, so using 0 + // as the byte size when the actual size is unknown can avoid regression by applying BF + // on a small table. + if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes + }).max Review comment: Can we rewrite this to: ```suggestion filterApplicationSide.collect({ case leaf: LeafNode if leaf.stats.sizeInBytes != defaultSizeInBytes => leaf.stats.sizeInBytes }).max ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
