This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch feature/crossjoin-array-contains-benchmark
in repository https://gitbox.apache.org/repos/asf/spark.git

commit af9bfec8e91f0192f00b06f520b18d7c7c44db11 Author: Kent Yao <[email protected]> AuthorDate: Wed Feb 4 09:45:58 2026 +0000 [SPARK-XXXX][SQL] Add CrossJoinArrayContainsToInnerJoin optimizer rule ### What changes were proposed in this pull request? This PR adds a new optimizer rule that converts cross joins with array_contains filter into inner joins using explode, improving query performance significantly. ### Why are the changes needed? Cross joins with array_contains predicates result in O(N*M) complexity. By transforming to explode + inner join, we achieve O(N+M) complexity. ### Does this PR introduce _any_ user-facing change? No. This is an internal optimization that automatically applies to applicable queries. ### How was this patch tested? - Unit tests in CrossJoinArrayContainsToInnerJoinSuite (6 tests) - Microbenchmark showing 11-16X speedup on representative workload ### Was this patch authored or co-authored using generative AI tooling? Yes, GitHub Copilot was used to assist with implementation. --- .../CrossJoinArrayContainsToInnerJoin.scala | 129 ++++++++++ .../spark/sql/catalyst/optimizer/Optimizer.scala | 1 + .../CrossJoinArrayContainsToInnerJoinSuite.scala | 274 +++++++++++++++++++++ ...yContainsToInnerJoinBenchmark-jdk21-results.txt | 38 +++ ...inArrayContainsToInnerJoinBenchmark-results.txt | 38 +++ ...rossJoinArrayContainsToInnerJoinBenchmark.scala | 227 +++++++++++++++++ 6 files changed, 707 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala new file mode 100644 index 000000000000..5b9ce16786cc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, JOIN} +import org.apache.spark.sql.types._ + +/** + * Converts cross joins with array_contains filter into inner joins using explode. + * + * This optimization transforms queries of the form: + * {{{ + * SELECT * FROM left, right WHERE array_contains(left.arr, right.elem) + * }}} + * + * Into a more efficient form using explode + inner join, reducing O(N*M) to O(N+M). 
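+ *
+ * The rewritten plan is equivalent to:
+ * {{{
+ * SELECT ... FROM
+ *   (SELECT *, explode(array_distinct(arr)) AS unnested FROM left) l
+ *   INNER JOIN right r ON l.unnested = r.elem
+ * }}}
+ * followed by a projection that drops the helper column. array_distinct
+ * avoids duplicate matches when the array repeats an element, and a null or
+ * empty array explodes to no rows, matching array_contains, which never
+ * evaluates to true in those cases.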
+ */ +object CrossJoinArrayContainsToInnerJoin extends Rule[LogicalPlan] with PredicateHelper { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( + _.containsAllPatterns(FILTER, JOIN)) { + case f @ Filter(cond, j @ Join(left, right, Cross | Inner, None, _)) => + tryTransform(f, cond, j, left, right).getOrElse(f) + } + + private def tryTransform( + filter: Filter, + condition: Expression, + join: Join, + left: LogicalPlan, + right: LogicalPlan): Option[LogicalPlan] = { + val predicates = splitConjunctivePredicates(condition) + val leftOut = left.outputSet + val rightOut = right.outputSet + + // Find first valid array_contains predicate + predicates.collectFirst { + case ac @ ArrayContains(arr, elem) + if canOptimize(arr, elem, leftOut, rightOut) => + val arrayOnLeft = arr.references.subsetOf(leftOut) + val remaining = predicates.filterNot(_ == ac) + buildPlan(join, left, right, arr, elem, arrayOnLeft, remaining) + }.flatten + } + + private def canOptimize( + arr: Expression, + elem: Expression, + leftOut: AttributeSet, + rightOut: AttributeSet): Boolean = { + // Check type compatibility + val elemType = elem.dataType + val validType = arr.dataType match { + case ArrayType(t, _) => t == elemType && isSupportedType(elemType) + case _ => false + } + + // Check array and element come from different sides + val arrRefs = arr.references + val elemRefs = elem.references + val crossesSides = (arrRefs.nonEmpty && elemRefs.nonEmpty) && ( + (arrRefs.subsetOf(leftOut) && elemRefs.subsetOf(rightOut)) || + (arrRefs.subsetOf(rightOut) && elemRefs.subsetOf(leftOut)) + ) + + validType && crossesSides + } + + /** + * Supported types have consistent equality semantics between array_contains and join. + * Excludes Float/Double (NaN issues) and complex types. 
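+ * BinaryType is supported: both array_contains and the join compare binary
+ * values by content (see the BinaryType case in the test suite), so their
+ * equality semantics agree.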
+ */ + private def isSupportedType(dt: DataType): Boolean = dt match { + case _: AtomicType => dt match { + case FloatType | DoubleType => false // NaN != NaN + case _ => true + } + case _ => false + } + + private def buildPlan( + join: Join, + left: LogicalPlan, + right: LogicalPlan, + arr: Expression, + elem: Expression, + arrayOnLeft: Boolean, + remaining: Seq[Expression]): Option[LogicalPlan] = { + + val unnestedAttr = AttributeReference("unnested", elem.dataType, nullable = true)() + val generator = Explode(ArrayDistinct(arr)) + + val (newLeft, newRight, joinCond) = if (arrayOnLeft) { + val gen = Generate(generator, Nil, false, None, Seq(unnestedAttr), left) + (gen, right, EqualTo(unnestedAttr, elem)) + } else { + val gen = Generate(generator, Nil, false, None, Seq(unnestedAttr), right) + (left, gen, EqualTo(elem, unnestedAttr)) + } + + val innerJoin = Join(newLeft, newRight, Inner, Some(joinCond), JoinHint.NONE) + + // Project to original output (exclude unnested column) + val projected = Project(join.output.map(a => Alias(a, a.name)(a.exprId)), innerJoin) + + // Add remaining predicates if any + val result = remaining.reduceLeftOption(And).map(Filter(_, projected)).getOrElse(projected) + Some(result) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index fe15819bd44a..0a018bfe08a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -261,6 +261,7 @@ abstract class Optimizer(catalogManager: CatalogManager) Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan), // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, + CrossJoinArrayContainsToInnerJoin, CheckCartesianProducts), Batch("RewriteSubquery", Once, RewritePredicateSubquery, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala new file mode 100644 index 000000000000..81fe449be71f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +/** + * Test suite for CrossJoinArrayContainsToInnerJoin optimizer rule. + * + * This rule converts cross joins with array_contains filter into inner joins + * using explode/unnest, which is much more efficient. + * + * Example transformation: + * {{{ + * Filter(array_contains(arr, elem)) + * CrossJoin(left, right) + * }}} + * becomes: + * {{{ + * InnerJoin(unnested = elem) + * Generate(Explode(ArrayDistinct(arr)), left) + * right + * }}} + */ +class CrossJoinArrayContainsToInnerJoinSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("CrossJoinArrayContainsToInnerJoin", Once, + CrossJoinArrayContainsToInnerJoin) :: Nil + } + + // Table with array column (simulates "orders" with item_ids array) + val ordersRelation: LocalRelation = LocalRelation( + $"order_id".int, + $"item_ids".array(IntegerType) + ) + + // Table with element column (simulates "items" with id) + val itemsRelation: LocalRelation = LocalRelation( + $"id".int, + $"name".string + ) + + test("converts cross join with array_contains to inner join with explode") { + // Original query: SELECT * FROM orders, items WHERE array_contains(item_ids, id) + val originalPlan = ordersRelation + .join(itemsRelation, Cross) + .where(ArrayContains($"item_ids", $"id")) + .analyze + + val optimized = Optimize.execute(originalPlan) + + // After optimization, should be an inner join with explode + // The plan should NOT contain a Cross join anymore + assert(!optimized.exists { + case j: Join if j.joinType == Cross => true + case _ => false + }, "Optimized plan should not contain Cross join") + + // Should contain a Generate (explode) node + assert(optimized.exists { + case _: Generate => true + case _ => false + }, "Optimized plan should contain Generate (explode) node") + + // Should contain an Inner join + assert(optimized.exists { + case j: Join if j.joinType == Inner => true + case _ => false + }, "Optimized plan should contain Inner join") + } + + test("does not transform when array_contains is not present") { + // Query without array_contains: SELECT * FROM orders, items WHERE order_id = id + val originalPlan = ordersRelation + .join(itemsRelation, Cross) + .where($"order_id" === $"id") + .analyze + + val optimized = Optimize.execute(originalPlan) + + // Should remain unchanged (still a cross join with filter) + assert(optimized.exists { + case j: Join if j.joinType == Cross => true + case _ => false + }, "Plan without array_contains should remain unchanged") + } + + test("does not transform inner join with existing conditions") { + // Already an inner join with equi-condition + val originalPlan = ordersRelation + .join(itemsRelation, Inner, Some($"order_id" === $"id")) + .where(ArrayContains($"item_ids", $"id")) + .analyze + + val optimized = Optimize.execute(originalPlan) + + // Should not add another explode since this is already an equi-join + // The array_contains becomes just a filter + assert(optimized.isInstanceOf[Filter] || optimized.exists { + case _: Filter => true + case _ => false + }) + } + + test("handles array column on right side of join") { + 
// Swap the tables - array is on right side + val rightWithArray: LocalRelation = LocalRelation( + $"arr_id".int, + $"values".array(IntegerType) + ) + val leftWithElement: LocalRelation = LocalRelation( + $"elem".int + ) + + val originalPlan = leftWithElement + .join(rightWithArray, Cross) + .where(ArrayContains($"values", $"elem")) + .analyze + + val optimized = Optimize.execute(originalPlan) + + // Should still be transformed + assert(!optimized.exists { + case j: Join if j.joinType == Cross => true + case _ => false + }, "Should transform even when array is on right side") + } + + test("preserves remaining filter predicates") { + // Query with additional conditions beyond array_contains + val originalPlan = ordersRelation + .join(itemsRelation, Cross) + .where(ArrayContains($"item_ids", $"id") && ($"order_id" > 100)) + .analyze + + val optimized = Optimize.execute(originalPlan) + + // Should still have a filter for the remaining predicate (order_id > 100) + assert(optimized.exists { + case Filter(cond, _) => + cond.find { + case GreaterThan(_, Literal(100, IntegerType)) => true + case _ => false + }.isDefined + case _ => false + }, "Should preserve remaining filter predicates") + } + + test("uses array_distinct to avoid duplicate matches") { + val originalPlan = ordersRelation + .join(itemsRelation, Cross) + .where(ArrayContains($"item_ids", $"id")) + .analyze + + val optimized = Optimize.execute(originalPlan) + + // The optimized plan should use ArrayDistinct before exploding + // to avoid duplicate rows when array has duplicate elements + assert(optimized.exists { + case Generate(Explode(ArrayDistinct(_)), _, _, _, _, _) => true + case Project(_, Generate(Explode(ArrayDistinct(_)), _, _, _, _, _)) => true + case _ => false + }, "Should use ArrayDistinct before Explode") + } + + test("supports ByteType elements") { + val leftRel = LocalRelation($"id".int, $"arr".array(ByteType)) + val rightRel = LocalRelation($"elem".byte) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = Optimize.execute(plan) + assert(optimized.exists { case _: Generate => true; case _ => false }) + } + + test("supports ShortType elements") { + val leftRel = LocalRelation($"id".int, $"arr".array(ShortType)) + val rightRel = LocalRelation($"elem".short) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = Optimize.execute(plan) + assert(optimized.exists { case _: Generate => true; case _ => false }) + } + + test("supports DecimalType elements") { + val leftRel = LocalRelation($"id".int, $"arr".array(DecimalType(10, 2))) + val rightRel = LocalRelation($"elem".decimal(10, 2)) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = Optimize.execute(plan) + assert(optimized.exists { case _: Generate => true; case _ => false }) + } + + test("supports TimestampType elements") { + val leftRel = LocalRelation($"id".int, $"arr".array(TimestampType)) + val rightRel = LocalRelation($"elem".timestamp) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = Optimize.execute(plan) + assert(optimized.exists { case _: Generate => true; case _ => false }) + } + + test("supports BooleanType elements") { + val leftRel = LocalRelation($"id".int, $"arr".array(BooleanType)) + val rightRel = LocalRelation($"elem".boolean) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = 
Optimize.execute(plan) + assert(optimized.exists { case _: Generate => true; case _ => false }) + } + + test("does not transform FloatType elements due to NaN semantics") { + val leftRel = LocalRelation($"id".int, $"arr".array(FloatType)) + val rightRel = LocalRelation($"elem".float) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = Optimize.execute(plan) + // Should NOT be transformed - still contains Cross join + assert(optimized.exists { + case j: Join if j.joinType == Cross => true + case _ => false + }, "FloatType should not be transformed due to NaN semantics") + } + + test("does not transform DoubleType elements due to NaN semantics") { + val leftRel = LocalRelation($"id".int, $"arr".array(DoubleType)) + val rightRel = LocalRelation($"elem".double) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = Optimize.execute(plan) + // Should NOT be transformed - still contains Cross join + assert(optimized.exists { + case j: Join if j.joinType == Cross => true + case _ => false + }, "DoubleType should not be transformed due to NaN semantics") + } + + test("supports BinaryType elements") { + // BinaryType is safe because Spark's join uses content-based hash/comparison + // via ByteArray.compareBinary, not Java's Array.equals() + val leftRel = LocalRelation($"id".int, $"arr".array(BinaryType)) + val rightRel = LocalRelation($"elem".binary) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = Optimize.execute(plan) + assert(optimized.exists { case _: Generate => true; case _ => false }) + } + + test("does not transform StructType elements") { + val structType = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))) + val leftRel = LocalRelation($"id".int, $"arr".array(structType)) + val rightRel = LocalRelation($"elem".struct(structType)) + val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze + val optimized = Optimize.execute(plan) + assert(optimized.exists { + case j: Join if j.joinType == Cross => true + case _ => false + }, "StructType should not be transformed due to complex equality semantics") + } +} diff --git a/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt new file mode 100644 index 000000000000..47b45cd26615 --- /dev/null +++ b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt @@ -0,0 +1,38 @@ +================================================================================================ +CrossJoinArrayContainsToInnerJoin Benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cross join with array_contains (1000 orders, 100 items, array size 5): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------------------------- +Cross join + array_contains filter (unoptimized) 52 69 15 1.9 520.8 1.0X +Inner join with explode (optimized equivalent) 56 74 19 1.8 564.9 0.9X +Inner join with explode (DataFrame API) 39 41 2 2.5 393.2 1.3X + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure +AMD EPYC 7763 64-Core 
Processor +Cross join with array_contains (10000 orders, 1000 items, array size 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------------------------------- +Cross join + array_contains filter (unoptimized) 582 596 19 17.2 58.2 1.0X +Inner join with explode (optimized equivalent) 36 39 3 276.2 3.6 16.1X +Inner join with explode (DataFrame API) 34 39 5 297.8 3.4 17.3X + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure +AMD EPYC 7763 64-Core Processor +Scalability: varying array sizes: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +array_size=1 with explode optimization 143 151 7 7.0 143.5 1.0X +array_size=5 with explode optimization 145 146 1 6.9 145.4 1.0X +array_size=10 with explode optimization 144 150 10 6.9 144.3 1.0X +array_size=50 with explode optimization 142 152 15 7.0 142.1 1.0X + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure +AMD EPYC 7763 64-Core Processor +Different data types in array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Integer array 31 39 7 32.4 30.9 1.0X +Long array 29 31 3 34.2 29.2 1.1X +String array 37 37 1 27.2 36.8 0.8X + + diff --git a/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt new file mode 100644 index 000000000000..8df73c946310 --- /dev/null +++ b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt @@ -0,0 +1,38 @@ +================================================================================================ +CrossJoinArrayContainsToInnerJoin Benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cross join with array_contains (1000 orders, 100 items, array size 5): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------------------------- +Cross join + array_contains filter (unoptimized) 52 56 3 1.9 523.4 1.0X +Inner join with explode (optimized equivalent) 60 62 2 1.7 598.3 0.9X +Inner join with explode (DataFrame API) 44 47 3 2.3 440.5 1.2X + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cross join with array_contains (10000 orders, 1000 items, array size 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------------------------------- +Cross join + array_contains filter (unoptimized) 504 533 25 19.8 50.4 1.0X +Inner join with explode (optimized equivalent) 45 45 0 221.9 4.5 11.2X +Inner join with explode (DataFrame API) 36 40 4 279.7 3.6 14.1X + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure +AMD EPYC 7763 64-Core Processor +Scalability: varying array sizes: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------------------------------ +array_size=1 with explode optimization 144 146 2 6.9 144.2 1.0X +array_size=5 with explode optimization 145 146 1 6.9 145.4 1.0X +array_size=10 with explode optimization 142 157 17 7.0 142.0 1.0X +array_size=50 with explode optimization 139 141 2 7.2 138.7 1.0X + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure +AMD EPYC 7763 64-Core Processor +Different data types in array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Integer array 29 33 6 34.2 29.3 1.0X +Long array 35 37 3 28.9 34.6 0.8X +String array 40 42 2 24.7 40.5 0.7X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala new file mode 100644 index 000000000000..2ac64d9099dc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Benchmark to measure performance improvement of CrossJoinArrayContainsToInnerJoin optimization. + * + * This benchmark compares: + * 1. Cross join with array_contains filter (unoptimized) + * 2. Inner join with explode (manually optimized / what the rule produces) + * + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class <this class> + * --jars <spark core test jar>,<spark catalyst test jar> <spark sql test jar> + * 2. build/sbt "sql/Test/runMain <this class>" + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>" + * Results will be written to + * "benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt". 
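+ * When run on Java 21, results are written to
+ * "benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt".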
+ * }}} + */ +object CrossJoinArrayContainsToInnerJoinBenchmark extends SqlBasedBenchmark { + + import spark.implicits._ + + private def crossJoinWithArrayContains(numOrders: Int, numItems: Int, arraySize: Int): Unit = { + val benchmark = new Benchmark( + s"Cross join with array_contains ($numOrders orders, $numItems items, array size $arraySize)", + numOrders.toLong * numItems, + output = output + ) + + // Create orders table with array of item IDs + val orders = spark.range(numOrders) + .selectExpr( + "id as order_id", + s"array_repeat(cast((id % $numItems) as int), $arraySize) as item_ids" + ) + .cache() + + // Create items table + val items = spark.range(numItems) + .selectExpr("cast(id as int) as item_id", "concat('item_', id) as item_name") + .cache() + + // Force caching + orders.count() + items.count() + + // Register as temp views for SQL queries + orders.createOrReplaceTempView("orders") + items.createOrReplaceTempView("items") + + benchmark.addCase("Cross join + array_contains filter (unoptimized)", numIters = 3) { _ => + // Disable the optimization to measure the true cross-join+filter baseline + withSQLConf( + SQLConf.CROSS_JOINS_ENABLED.key -> "true", + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.CrossJoinArrayContainsToInnerJoin") { + // This query would be a cross join with filter without optimization + val df = spark.sql( + """ + |SELECT /*+ BROADCAST(items) */ o.order_id, i.item_id, i.item_name + |FROM orders o, items i + |WHERE array_contains(o.item_ids, i.item_id) + """.stripMargin) + df.noop() + } + } + + benchmark.addCase("Inner join with explode (optimized equivalent)", numIters = 3) { _ => + // This is what the optimization produces - explode + inner join + val df = spark.sql( + """ + |SELECT o.order_id, i.item_id, i.item_name + |FROM ( + | SELECT order_id, explode(array_distinct(item_ids)) as unnested_id + | FROM orders + |) o + |INNER JOIN items i ON o.unnested_id = i.item_id + """.stripMargin) + df.noop() + } + + benchmark.addCase("Inner join with explode (DataFrame API)", numIters = 3) { _ => + val ordersExploded = orders + .withColumn("unnested_id", explode(array_distinct($"item_ids"))) + .select($"order_id", $"unnested_id") + + val df = ordersExploded.join(items, $"unnested_id" === $"item_id") + df.noop() + } + + benchmark.run() + + orders.unpersist() + items.unpersist() + } + + private def scalabilityBenchmark(): Unit = { + val benchmark = new Benchmark( + "Scalability: varying array sizes", + 1000000L, + output = output + ) + + val numOrders = 10000 + val numItems = 1000 + + Seq(1, 5, 10, 50).foreach { arraySize => + val orders = spark.range(numOrders) + .selectExpr( + "id as order_id", + s"transform(sequence(0, $arraySize - 1), " + + s"x -> cast((id + x) % $numItems as int)) as item_ids" + ) + + val items = spark.range(numItems) + .selectExpr("cast(id as int) as item_id", "concat('item_', id) as item_name") + + orders.createOrReplaceTempView("orders_scale") + items.createOrReplaceTempView("items_scale") + + benchmark.addCase(s"array_size=$arraySize with explode optimization", numIters = 3) { _ => + val df = spark.sql( + """ + |SELECT o.order_id, i.item_id, i.item_name + |FROM ( + | SELECT order_id, explode(array_distinct(item_ids)) as unnested_id + | FROM orders_scale + |) o + |INNER JOIN items_scale i ON o.unnested_id = i.item_id + """.stripMargin) + df.noop() + } + } + + benchmark.run() + } + + private def dataTypeBenchmark(): Unit = { + val benchmark = new Benchmark( + "Different data types in array", + 1000000L, + 
output = output + ) + + val numRows = 10000 + val numLookup = 1000 + val arraySize = 10 + + // Integer arrays + benchmark.addCase("Integer array", numIters = 3) { _ => + val left = spark.range(numRows) + .selectExpr("id", s"array_repeat(cast(id % $numLookup as int), $arraySize) as arr") + val right = spark.range(numLookup).selectExpr("cast(id as int) as elem") + + val df = left + .withColumn("unnested", explode(array_distinct($"arr"))) + .join(right, $"unnested" === $"elem") + df.noop() + } + + // Long arrays + benchmark.addCase("Long array", numIters = 3) { _ => + val left = spark.range(numRows) + .selectExpr("id", s"array_repeat(id % $numLookup, $arraySize) as arr") + val right = spark.range(numLookup).selectExpr("id as elem") + + val df = left + .withColumn("unnested", explode(array_distinct($"arr"))) + .join(right, $"unnested" === $"elem") + df.noop() + } + + // String arrays + benchmark.addCase("String array", numIters = 3) { _ => + val left = spark.range(numRows) + .selectExpr("id", s"array_repeat(concat('key_', id % $numLookup), $arraySize) as arr") + val right = spark.range(numLookup).selectExpr("concat('key_', id) as elem") + + val df = left + .withColumn("unnested", explode(array_distinct($"arr"))) + .join(right, $"unnested" === $"elem") + df.noop() + } + + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("CrossJoinArrayContainsToInnerJoin Benchmark") { + // Small scale test + crossJoinWithArrayContains(numOrders = 1000, numItems = 100, arraySize = 5) + + // Medium scale test + crossJoinWithArrayContains(numOrders = 10000, numItems = 1000, arraySize = 10) + + // Scalability test with varying array sizes + scalabilityBenchmark() + + // Data type comparison + dataTypeBenchmark() + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
