Repository: spark Updated Branches: refs/heads/branch-2.0 4b38a6a53 -> 4391d4a3c
[SPARK-16633][SPARK-16642][SPARK-16721][SQL] Fixes three issues related to lead and lag functions ## What changes were proposed in this pull request? This PR contains three changes. First, this PR changes the behavior of lead/lag back to Spark 1.6's behavior, which is described below: 1. lead/lag respect null input values, which means that if the offset row exists and the input value is null, the result will be null instead of the default value. 2. If the offset row does not exist, the default value will be used. 3. OffsetWindowFunction's nullable setting also considers the nullability of its input (because of the first change). Second, this PR fixes the evaluation of lead/lag when the input expression is a literal. This fix is a result of the first change. In the current master, if a literal is used as the input expression of a lead or lag function, the result will be this literal even if the offset row does not exist. Third, this PR prevents ResolveWindowFrame from firing when a window function is not yet resolved. ## How was this patch tested? New tests in SQLWindowFunctionSuite. Author: Yin Huai <yh...@databricks.com> Closes #14284 from yhuai/lead-lag. 
(cherry picked from commit 815f3eece5f095919a329af8cbd762b9ed71c7a8) Signed-off-by: Yin Huai <yh...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4391d4a3 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4391d4a3 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4391d4a3 Branch: refs/heads/branch-2.0 Commit: 4391d4a3c60d59df625cbfdb918aa67c51ebcbc1 Parents: 4b38a6a Author: Yin Huai <yh...@databricks.com> Authored: Mon Jul 25 20:58:07 2016 -0700 Committer: Yin Huai <yh...@databricks.com> Committed: Mon Jul 25 20:58:57 2016 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 +- .../expressions/windowExpressions.scala | 45 +- .../apache/spark/sql/execution/WindowExec.scala | 34 +- .../sql/execution/SQLWindowFunctionSuite.scala | 414 +++++++++++++++++++ .../hive/execution/SQLWindowFunctionSuite.scala | 370 ----------------- 5 files changed, 467 insertions(+), 399 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/4391d4a3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d1d2c59..61162cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1787,7 +1787,8 @@ class Analyzer( s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if wf.frame != UnspecifiedFrame => WindowExpression(wf, s.copy(frameSpecification = wf.frame)) - case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => 
+ case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + if e.resolved => val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) we.copy(windowSpec = s.copy(frameSpecification = frame)) } http://git-wip-us.apache.org/repos/asf/spark/blob/4391d4a3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index e35192c..6806591 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -321,8 +321,7 @@ abstract class OffsetWindowFunction val input: Expression /** - * Default result value for the function when the input expression returns NULL. The default will - * evaluated against the current row instead of the offset row. + * Default result value for the function when the 'offset'th row does not exist. */ val default: Expression @@ -348,7 +347,7 @@ abstract class OffsetWindowFunction */ override def foldable: Boolean = false - override def nullable: Boolean = default == null || default.nullable + override def nullable: Boolean = default == null || default.nullable || input.nullable override lazy val frame = { // This will be triggered by the Analyzer. @@ -373,20 +372,22 @@ abstract class OffsetWindowFunction } /** - * The Lead function returns the value of 'x' at 'offset' rows after the current row in the window. - * Offsets start at 0, which is the current row. The offset must be constant integer value. The - * default offset is 1. 
When the value of 'x' is null at the offset, or when the offset is larger - * than the window, the default expression is evaluated. - * - * This documentation has been based upon similar documentation for the Hive and Presto projects. + * The Lead function returns the value of 'x' at the 'offset'th row after the current row in + * the window. Offsets start at 0, which is the current row. The offset must be constant + * integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row, + * null is returned. If there is no such offset row, the default expression is evaluated. * * @param input expression to evaluate 'offset' rows after the current row. * @param offset rows to jump ahead in the partition. - * @param default to use when the input value is null or when the offset is larger than the window. + * @param default to use when the offset is larger than the window. The default value is null. */ @ExpressionDescription(usage = - """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows - after the current row in the window""") + """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at the 'offset'th row + after the current row in the window. + The default value of 'offset' is 1 and the default value of 'default' is null. + If the value of 'x' at the 'offset'th row is null, null is returned. + If there is no such offset row (e.g. when the offset is 1, the last row of the window + does not have any subsequent row), 'default' is returned.""") case class Lead(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { @@ -400,20 +401,22 @@ case class Lead(input: Expression, offset: Expression, default: Expression) } /** - * The Lag function returns the value of 'x' at 'offset' rows before the current row in the window. - * Offsets start at 0, which is the current row. The offset must be constant integer value. The - * default offset is 1. 
When the value of 'x' is null at the offset, or when the offset is smaller - * than the window, the default expression is evaluated. - * - * This documentation has been based upon similar documentation for the Hive and Presto projects. + * The Lag function returns the value of 'x' at the 'offset'th row before the current row in + * the window. Offsets start at 0, which is the current row. The offset must be constant + * integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row, + * null is returned. If there is no such offset row, the default expression is evaluated. * * @param input expression to evaluate 'offset' rows before the current row. * @param offset rows to jump back in the partition. - * @param default to use when the input value is null or when the offset is smaller than the window. + * @param default to use when the offset row does not exist. */ @ExpressionDescription(usage = - """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows - before the current row in the window""") + """_FUNC_(input, offset, default) - LAG returns the value of 'x' at the 'offset'th row + before the current row in the window. + The default value of 'offset' is 1 and the default value of 'default' is null. + If the value of 'x' at the 'offset'th row is null, null is returned. + If there is no such offset row (e.g. 
when the offset is 1, the first row of the window + does not have any previous row), 'default' is returned.""") case class Lag(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { http://git-wip-us.apache.org/repos/asf/spark/blob/4391d4a3/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala index e01094a..3927a50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala @@ -582,25 +582,43 @@ private[execution] final class OffsetWindowFunctionFrame( /** Row used to combine the offset and the current row. */ private[this] val join = new JoinedRow - /** Create the projection. */ + /** + * Create the projection used when the offset row exists. + * Please note that this project always respect null input values (like PostgreSQL). + */ private[this] val projection = { // Collect the expressions and bind them. val inputAttrs = inputSchema.map(_.withNullability(true)) - val numInputAttributes = inputAttrs.size val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { case e: OffsetWindowFunction => val input = BindReferences.bindReference(e.input, inputAttrs) + input + case e => + BindReferences.bindReference(e, inputAttrs) + } + + // Create the projection. + newMutableProjection(boundExpressions, Nil).target(target) + } + + /** Create the projection used when the offset row DOES NOT exists. */ + private[this] val fillDefaultValue = { + // Collect the expressions and bind them. 
+ val inputAttrs = inputSchema.map(_.withNullability(true)) + val numInputAttributes = inputAttrs.size + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { + case e: OffsetWindowFunction => if (e.default == null || e.default.foldable && e.default.eval() == null) { - // Without default value. - input + // The default value is null. + Literal.create(null, e.dataType) } else { - // With default value. + // The default value is an expression. val default = BindReferences.bindReference(e.default, inputAttrs).transform { // Shift the input reference to its default version. case BoundReference(o, dataType, nullable) => BoundReference(o + numInputAttributes, dataType, nullable) } - org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) + default } case e => BindReferences.bindReference(e, inputAttrs) @@ -625,10 +643,12 @@ private[execution] final class OffsetWindowFunctionFrame( if (inputIndex >= 0 && inputIndex < input.size) { val r = input.next() join(r, current) + projection(join) } else { join(emptyRow, current) + // Use default values since the offset row does not exist. + fillDefaultValue(join) } - projection(join) inputIndex += 1 } } http://git-wip-us.apache.org/repos/asf/spark/blob/4391d4a3/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala new file mode 100644 index 0000000..d3cfa95 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.test.SharedSQLContext + +case class WindowData(month: Int, area: String, product: Int) + + +/** + * Test suite for SQL window functions. + */ +class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + test("window function: udaf with aggregate expression") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") + + checkAnswer( + sql( + """ + |select area, sum(product), sum(sum(product)) over (partition by area) + |from windowData group by month, area + """.stripMargin), + Seq( + ("a", 5, 11), + ("a", 6, 11), + ("b", 7, 15), + ("b", 8, 15), + ("c", 9, 19), + ("c", 10, 19) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, sum(product) - 1, sum(sum(product)) over (partition by area) + |from windowData group by month, area + """.stripMargin), + Seq( + ("a", 4, 11), + ("a", 5, 11), + ("b", 6, 15), + ("b", 7, 15), + ("c", 8, 19), + ("c", 9, 19) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, sum(product), sum(product) / sum(sum(product)) over 
(partition by area) + |from windowData group by month, area + """.stripMargin), + Seq( + ("a", 5, 5d/11), + ("a", 6, 6d/11), + ("b", 7, 7d/15), + ("b", 8, 8d/15), + ("c", 10, 10d/19), + ("c", 9, 9d/19) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, sum(product), sum(product) / sum(sum(product) - 1) over (partition by area) + |from windowData group by month, area + """.stripMargin), + Seq( + ("a", 5, 5d/9), + ("a", 6, 6d/9), + ("b", 7, 7d/13), + ("b", 8, 8d/13), + ("c", 10, 10d/17), + ("c", 9, 9d/17) + ).map(i => Row(i._1, i._2, i._3))) + } + + test("window function: refer column in inner select block") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product, 1 as tmp1 from windowData) tmp + """.stripMargin), + Seq( + ("a", 2), + ("a", 3), + ("b", 2), + ("b", 3), + ("c", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + } + + test("window function: partition and order expressions") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") + + checkAnswer( + sql( + """ + |select month, area, product, sum(product + 1) over (partition by 1 order by 2) + |from windowData + """.stripMargin), + Seq( + (1, "a", 5, 51), + (2, "a", 6, 51), + (3, "b", 7, 51), + (4, "b", 8, 51), + (5, "c", 9, 51), + (6, "c", 10, 51) + ).map(i => Row(i._1, i._2, i._3, i._4))) + + checkAnswer( + sql( + """ + |select month, area, product, sum(product) + |over (partition by month % 2 order by 10 - product) + |from windowData + 
""".stripMargin), + Seq( + (1, "a", 5, 21), + (2, "a", 6, 24), + (3, "b", 7, 16), + (4, "b", 8, 18), + (5, "c", 9, 9), + (6, "c", 10, 10) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("window function: distinct should not be silently ignored") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") + + val e = intercept[AnalysisException] { + sql( + """ + |select month, area, product, sum(distinct product + 1) over (partition by 1 order by 2) + |from windowData + """.stripMargin) + } + assert(e.getMessage.contains("Distinct window functions are not supported")) + } + + test("window function: expressions in arguments of a window functions") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") + + checkAnswer( + sql( + """ + |select month, area, month % 2, + |lag(product, 1 + 1, product) over (partition by month % 2 order by area) + |from windowData + """.stripMargin), + Seq( + (1, "a", 1, 5), + (2, "a", 0, 6), + (3, "b", 1, 7), + (4, "b", 0, 8), + (5, "c", 1, 5), + (6, "c", 0, 6) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + + test("window function: Sorting columns are not in Project") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") + + checkAnswer( + sql("select month, product, sum(product + 1) over() from windowData order by area"), + Seq( + (2, 6, 57), + (3, 7, 57), + (4, 8, 57), + (5, 9, 57), + (6, 11, 57), + (1, 10, 57) + ).map(i => Row(i._1, i._2, i._3))) + 
+ checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p + """.stripMargin), + Seq( + ("a", 2), + ("b", 2), + ("b", 3), + ("c", 2), + ("d", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by month) as c1 + |from windowData group by product, area, month order by product, area + """.stripMargin), + Seq( + ("a", 1), + ("b", 1), + ("b", 2), + ("c", 1), + ("d", 1), + ("c", 2) + ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, sum(product) / sum(sum(product)) over (partition by area) as c1 + |from windowData group by area, month order by month, c1 + """.stripMargin), + Seq( + ("d", 1.0), + ("a", 1.0), + ("b", 0.4666666666666667), + ("b", 0.5333333333333333), + ("c", 0.45), + ("c", 0.55) + ).map(i => Row(i._1, i._2))) + } + + // todo: fix this test case by reimplementing the function ResolveAggregateFunctions + ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") + + checkAnswer( + sql( + """ + |select area, sum(product) over () as c from windowData + |where product > 3 group by area, product + |having avg(month) > 0 order by avg(month), product + """.stripMargin), + Seq( + ("a", 51), + ("b", 51), + ("b", 51), + ("c", 51), + ("c", 51), + ("d", 51) + ).map(i => Row(i._1, i._2))) + } + + test("window function: multiple window expressions in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.createOrReplaceTempView("nums") + + val expected = + Row(1, 1, 1, 55, 1, 57) :: + Row(0, 2, 3, 55, 2, 60) :: 
+ Row(1, 3, 6, 55, 4, 65) :: + Row(0, 4, 10, 55, 6, 71) :: + Row(1, 5, 15, 55, 9, 79) :: + Row(0, 6, 21, 55, 12, 88) :: + Row(1, 7, 28, 55, 16, 99) :: + Row(0, 8, 36, 55, 20, 111) :: + Row(1, 9, 45, 55, 25, 125) :: + Row(0, 10, 55, 55, 30, 140) :: Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) OVER w1 AS running_sum, + | sum(x) OVER w2 AS total_sum, + | sum(x) OVER w3 AS running_sum_per_y, + | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2 + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW), + | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING), + | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + """.stripMargin) + + checkAnswer(actual, expected) + + spark.catalog.dropTempView("nums") + } + + test("SPARK-7595: Window will cause resolve failed with self join") { + checkAnswer(sql( + """ + |with + | v0 as (select 0 as key, 1 as value), + | v1 as (select key, count(value) over (partition by key) cnt_val from v0), + | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key) + | select key, cnt_val from v2 order by key limit 1 + """.stripMargin), Row(0, 1)) + } + + test("SPARK-16633: lead/lag should return the default value if the offset row does not exist") { + checkAnswer(sql( + """ + |SELECT + | lag(123, 100, 321) OVER (ORDER BY id) as lag, + | lead(123, 100, 321) OVER (ORDER BY id) as lead + |FROM (SELECT 1 as id) tmp + """.stripMargin), + Row(321, 321)) + + checkAnswer(sql( + """ + |SELECT + | lag(123, 100, a) OVER (ORDER BY id) as lag, + | lead(123, 100, a) OVER (ORDER BY id) as lead + |FROM (SELECT 1 as id, 2 as a) tmp + """.stripMargin), + Row(2, 2)) + } + + test("lead/lag should respect null values") { + checkAnswer(sql( + """ + |SELECT + | b, + | lag(a, 1, 321) OVER (ORDER BY b) as lag, + | lead(a, 1, 321) OVER (ORDER BY b) as lead + |FROM (SELECT cast(null as int) as a, 1 as b + | UNION ALL 
+ | select cast(null as int) as id, 2 as b) tmp + """.stripMargin), + Row(1, 321, null) :: Row(2, null, 321) :: Nil) + + checkAnswer(sql( + """ + |SELECT + | b, + | lag(a, 1, c) OVER (ORDER BY b) as lag, + | lead(a, 1, c) OVER (ORDER BY b) as lead + |FROM (SELECT cast(null as int) as a, 1 as b, 3 as c + | UNION ALL + | select cast(null as int) as id, 2 as b, 4 as c) tmp + """.stripMargin), + Row(1, 3, null) :: Row(2, null, 4) :: Nil) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4391d4a3/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala deleted file mode 100644 index 77e97df..0000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala +++ /dev/null @@ -1,370 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils - - -case class WindowData(month: Int, area: String, product: Int) - - -/** - * Test suite for SQL window functions. - */ -class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - import spark.implicits._ - - test("window function: udaf with aggregate expression") { - val data = Seq( - WindowData(1, "a", 5), - WindowData(2, "a", 6), - WindowData(3, "b", 7), - WindowData(4, "b", 8), - WindowData(5, "c", 9), - WindowData(6, "c", 10) - ) - sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") - - checkAnswer( - sql( - """ - |select area, sum(product), sum(sum(product)) over (partition by area) - |from windowData group by month, area - """.stripMargin), - Seq( - ("a", 5, 11), - ("a", 6, 11), - ("b", 7, 15), - ("b", 8, 15), - ("c", 9, 19), - ("c", 10, 19) - ).map(i => Row(i._1, i._2, i._3))) - - checkAnswer( - sql( - """ - |select area, sum(product) - 1, sum(sum(product)) over (partition by area) - |from windowData group by month, area - """.stripMargin), - Seq( - ("a", 4, 11), - ("a", 5, 11), - ("b", 6, 15), - ("b", 7, 15), - ("c", 8, 19), - ("c", 9, 19) - ).map(i => Row(i._1, i._2, i._3))) - - checkAnswer( - sql( - """ - |select area, sum(product), sum(product) / sum(sum(product)) over (partition by area) - |from windowData group by month, area - """.stripMargin), - Seq( - ("a", 5, 5d/11), - ("a", 6, 6d/11), - ("b", 7, 7d/15), - ("b", 8, 8d/15), - ("c", 10, 10d/19), - ("c", 9, 9d/19) - ).map(i => Row(i._1, i._2, i._3))) - - checkAnswer( - sql( - """ - |select area, sum(product), sum(product) / sum(sum(product) - 1) over (partition by area) - |from windowData group by month, area - """.stripMargin), - Seq( - ("a", 5, 5d/9), - ("a", 6, 6d/9), - ("b", 7, 7d/13), - ("b", 8, 8d/13), - ("c", 10, 
10d/17), - ("c", 9, 9d/17) - ).map(i => Row(i._1, i._2, i._3))) - } - - test("window function: refer column in inner select block") { - val data = Seq( - WindowData(1, "a", 5), - WindowData(2, "a", 6), - WindowData(3, "b", 7), - WindowData(4, "b", 8), - WindowData(5, "c", 9), - WindowData(6, "c", 10) - ) - sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") - - checkAnswer( - sql( - """ - |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 - |from (select month, area, product, 1 as tmp1 from windowData) tmp - """.stripMargin), - Seq( - ("a", 2), - ("a", 3), - ("b", 2), - ("b", 3), - ("c", 2), - ("c", 3) - ).map(i => Row(i._1, i._2))) - } - - test("window function: partition and order expressions") { - val data = Seq( - WindowData(1, "a", 5), - WindowData(2, "a", 6), - WindowData(3, "b", 7), - WindowData(4, "b", 8), - WindowData(5, "c", 9), - WindowData(6, "c", 10) - ) - sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") - - checkAnswer( - sql( - """ - |select month, area, product, sum(product + 1) over (partition by 1 order by 2) - |from windowData - """.stripMargin), - Seq( - (1, "a", 5, 51), - (2, "a", 6, 51), - (3, "b", 7, 51), - (4, "b", 8, 51), - (5, "c", 9, 51), - (6, "c", 10, 51) - ).map(i => Row(i._1, i._2, i._3, i._4))) - - checkAnswer( - sql( - """ - |select month, area, product, sum(product) - |over (partition by month % 2 order by 10 - product) - |from windowData - """.stripMargin), - Seq( - (1, "a", 5, 21), - (2, "a", 6, 24), - (3, "b", 7, 16), - (4, "b", 8, 18), - (5, "c", 9, 9), - (6, "c", 10, 10) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("window function: distinct should not be silently ignored") { - val data = Seq( - WindowData(1, "a", 5), - WindowData(2, "a", 6), - WindowData(3, "b", 7), - WindowData(4, "b", 8), - WindowData(5, "c", 9), - WindowData(6, "c", 10) - ) - sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") - - val e = 
intercept[AnalysisException] { - sql( - """ - |select month, area, product, sum(distinct product + 1) over (partition by 1 order by 2) - |from windowData - """.stripMargin) - } - assert(e.getMessage.contains("Distinct window functions are not supported")) - } - - test("window function: expressions in arguments of a window functions") { - val data = Seq( - WindowData(1, "a", 5), - WindowData(2, "a", 6), - WindowData(3, "b", 7), - WindowData(4, "b", 8), - WindowData(5, "c", 9), - WindowData(6, "c", 10) - ) - sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") - - checkAnswer( - sql( - """ - |select month, area, month % 2, - |lag(product, 1 + 1, product) over (partition by month % 2 order by area) - |from windowData - """.stripMargin), - Seq( - (1, "a", 1, 5), - (2, "a", 0, 6), - (3, "b", 1, 7), - (4, "b", 0, 8), - (5, "c", 1, 5), - (6, "c", 0, 6) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - - test("window function: Sorting columns are not in Project") { - val data = Seq( - WindowData(1, "d", 10), - WindowData(2, "a", 6), - WindowData(3, "b", 7), - WindowData(4, "b", 8), - WindowData(5, "c", 9), - WindowData(6, "c", 11) - ) - sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") - - checkAnswer( - sql("select month, product, sum(product + 1) over() from windowData order by area"), - Seq( - (2, 6, 57), - (3, 7, 57), - (4, 8, 57), - (5, 9, 57), - (6, 11, 57), - (1, 10, 57) - ).map(i => Row(i._1, i._2, i._3))) - - checkAnswer( - sql( - """ - |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 - |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p - """.stripMargin), - Seq( - ("a", 2), - ("b", 2), - ("b", 3), - ("c", 2), - ("d", 2), - ("c", 3) - ).map(i => Row(i._1, i._2))) - - checkAnswer( - sql( - """ - |select area, rank() over (partition by area order by month) as c1 - |from windowData group by product, area, month order by product, area - """.stripMargin), 
- Seq( - ("a", 1), - ("b", 1), - ("b", 2), - ("c", 1), - ("d", 1), - ("c", 2) - ).map(i => Row(i._1, i._2))) - - checkAnswer( - sql( - """ - |select area, sum(product) / sum(sum(product)) over (partition by area) as c1 - |from windowData group by area, month order by month, c1 - """.stripMargin), - Seq( - ("d", 1.0), - ("a", 1.0), - ("b", 0.4666666666666667), - ("b", 0.5333333333333333), - ("c", 0.45), - ("c", 0.55) - ).map(i => Row(i._1, i._2))) - } - - // todo: fix this test case by reimplementing the function ResolveAggregateFunctions - ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") { - val data = Seq( - WindowData(1, "d", 10), - WindowData(2, "a", 6), - WindowData(3, "b", 7), - WindowData(4, "b", 8), - WindowData(5, "c", 9), - WindowData(6, "c", 11) - ) - sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") - - checkAnswer( - sql( - """ - |select area, sum(product) over () as c from windowData - |where product > 3 group by area, product - |having avg(month) > 0 order by avg(month), product - """.stripMargin), - Seq( - ("a", 51), - ("b", 51), - ("b", 51), - ("c", 51), - ("c", 51), - ("d", 51) - ).map(i => Row(i._1, i._2))) - } - - test("window function: multiple window expressions in a single expression") { - val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") - nums.createOrReplaceTempView("nums") - - val expected = - Row(1, 1, 1, 55, 1, 57) :: - Row(0, 2, 3, 55, 2, 60) :: - Row(1, 3, 6, 55, 4, 65) :: - Row(0, 4, 10, 55, 6, 71) :: - Row(1, 5, 15, 55, 9, 79) :: - Row(0, 6, 21, 55, 12, 88) :: - Row(1, 7, 28, 55, 16, 99) :: - Row(0, 8, 36, 55, 20, 111) :: - Row(1, 9, 45, 55, 25, 125) :: - Row(0, 10, 55, 55, 30, 140) :: Nil - - val actual = sql( - """ - |SELECT - | y, - | x, - | sum(x) OVER w1 AS running_sum, - | sum(x) OVER w2 AS total_sum, - | sum(x) OVER w3 AS running_sum_per_y, - | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2 - |FROM nums - |WINDOW w1 AS 
(ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW), - | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING), - | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) - """.stripMargin) - - checkAnswer(actual, expected) - - spark.catalog.dropTempView("nums") - } - - test("SPARK-7595: Window will cause resolve failed with self join") { - sql("SELECT * FROM src") // Force loading of src table. - - checkAnswer(sql( - """ - |with - | v1 as (select key, count(value) over (partition by key) cnt_val from src), - | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key) - | select * from v2 order by key limit 1 - """.stripMargin), Row(0, 3)) - } -} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org