uros-b commented on code in PR #56291: URL: https://github.com/apache/spark/pull/56291#discussion_r3465944619
########## sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala: ########## @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +/** + * End-to-end correctness tests for the segment-tree shrinking-frame path + * (`... ROWS/RANGE BETWEEN <lower> AND UNBOUNDED FOLLOWING`). + * + * Mirrors the structure of [[SegmentTreeWindowFunctionSuite]]: every test + * runs the same SQL with `spark.sql.window.segmentTree.enabled` off and on + * and asserts row-set equality. The "off" path runs through + * [[UnboundedFollowingWindowFunctionFrame]] (the O(N^2) baseline); the "on" + * path runs through the new shrinking branch in + * [[SegmentTreeWindowFunctionFrame]] (`ubound = None`). + */ +class UnboundedFollowingSegmentTreeSuite extends SharedSparkSession { + + import testImplicits._ + + private val enableSegTree: Map[String, String] = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1") + + private val disableSegTree: Map[String, String] = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false") + + /** Baseline (flag off) vs segtree (flag on); compare row-sets. */ + private def checkEquivalence(build: () => DataFrame): Unit = { + val baseline: Seq[Row] = withSQLConf(disableSegTree.toSeq: _*) { + build().collect().toSeq + } + withSQLConf(enableSegTree.toSeq: _*) { + val actual = build().collect().toSeq + QueryTest.sameRows(baseline, actual, isSorted = false).foreach { err => + fail(s"shrinking-frame segtree output differs from baseline.\n$err") + } + } + } + + /** SQL-level variant that accepts a query string. */ + private def checkSqlEquivalence(df: DataFrame, query: String): Unit = { + df.createOrReplaceTempView("t") + try { + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + spark.sql(query).collect().sortBy(_.toString) + } + withSQLConf(enableSegTree.toSeq: _*) { + val actual = spark.sql(query).collect().sortBy(_.toString) + assert(actual.toSeq === baseline.toSeq, + s"shrinking-frame segtree output differs from baseline.\n" + + s"Expected: ${baseline.toSeq}\nActual: ${actual.toSeq}") + } + } finally { + spark.catalog.dropTempView("t") + } + } + + /** 3 partitions, 40 rows each; values = row index. */ + private def baseDF: DataFrame = + spark.range(0, 120).selectExpr( + "id", + "(id % 3) AS pk", + "CAST(id AS INT) AS v") + + /** Shrinking ROWS frame: [lo, end-of-partition). */ + private def shrinkingRowsFrame(lo: Int) = + Window.partitionBy($"pk").orderBy($"id") + .rowsBetween(lo, Window.unboundedFollowing) + + // ============================================================ + // ROWS frame: basic aggregate equivalence (CURRENT ROW lower) + // ============================================================ + + test("MIN over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", min($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("MAX over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", max($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("SUM over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("COUNT over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", count($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("AVG over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", avg($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + // ============================================================ + // ROWS frame: lower-bound variations + // ============================================================ + + test("ROWS BETWEEN 5 PRECEDING AND UNBOUNDED FOLLOWING (suffix + lookback)") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(-5)).as("agg"))) + } + + test("ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING is NOT this path") { + // Both-unbounded routes to UnboundedWindowFunctionFrame (different case + // in the dispatcher) and is one-shot O(1). This test just verifies the + // segtree flag doesn't break it. + val frame = Window.partitionBy($"pk").orderBy($"id") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(frame).as("agg"))) + } + + test("ROWS BETWEEN 5 FOLLOWING AND UNBOUNDED FOLLOWING (lower bound is positive)") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(5)).as("agg"))) + } + + // ============================================================ + // Multi-aggregate: shared frame + // ============================================================ + + test("MIN + MAX + SUM share a single shrinking frame") { + checkEquivalence(() => + baseDF.select( + $"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + sum($"v").over(shrinkingRowsFrame(0)).as("s"))) + } + + // ============================================================ + // Partition / boundary edge cases + // ============================================================ + + test("single-row partition") { + val df = spark.range(0, 5).selectExpr("id", "id AS pk", "CAST(id AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("empty result table (no rows)") { + val df = spark.emptyDataFrame.selectExpr("CAST(NULL AS BIGINT) AS id", + "CAST(NULL AS BIGINT) AS pk", "CAST(NULL AS INT) AS v") + .where("id IS NOT NULL") + checkEquivalence(() => + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("partition below minPartitionRows falls back to UnboundedFollowingWindowFunctionFrame") { + // With minRows=1024 the segtree path forces fallback; baseline (off) and + // forced-fallback (on, but min=1024) must match. The point is that the + // small-partition path goes through the legacy frame, not segtree. + val df = baseDF + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("s")) + .collect().toSeq + } + withSQLConf( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1024") { + val actual = df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("s")) + .collect().toSeq + QueryTest.sameRows(baseline, actual, isSorted = false).foreach { err => + fail(s"forced-fallback path diverges from baseline.\n$err") + } + } + } + Review Comment: Minor coverage gap: UnboundedFollowingSegmentTreeSuite.scala has no forced-fallback test for the RANGE shrinking path (the one fallback test, "partition below minPartitionRows", uses a ROWS frame). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
