This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 3d25525d876b [SPARK-46993][SQL] Fix constant folding for session variables 3d25525d876b is described below commit 3d25525d876be9c5f3bd2dce917cb6bec10fb3e9 Author: Serge Rielau <se...@rielau.com> AuthorDate: Thu Feb 8 16:24:03 2024 +0800 [SPARK-46993][SQL] Fix constant folding for session variables ### What changes were proposed in this pull request? Remove the unconditional Alias node generation when resolving a variable reference. ### Why are the changes needed? An Alias that is not at the top level of an expression blocks constant folding. Constant folding, in turn, is required for variables to be usable as arguments to numerous functions, such as from_json(). The unnecessary Alias also has performance implications. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing regression tests in sql-session-variables.sql; added a test to validate the fix. ### Was this patch authored or co-authored using generative AI tooling? No Closes #45059 from srielau/SPARK-46993-Fix-constant-folding-for-session-variables. 
Authored-by: Serge Rielau <se...@rielau.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../catalyst/analysis/ColumnResolutionHelper.scala | 35 ++++++++++++---- .../analyzer-results/sql-session-variables.sql.out | 43 ++++++++++++++++---- .../sql-tests/inputs/sql-session-variables.sql | 6 +++ .../results/sql-session-variables.sql.out | 46 ++++++++++++++++++---- 4 files changed, 109 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 2472705d2f54..8ea50e2ceb65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -312,14 +312,35 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { }.map(e => Alias(e, nameParts.last)()) } - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { - case u: UnresolvedAttribute => - resolve(u.nameParts).getOrElse(u) - // Re-resolves `TempResolvedColumn` as variable references if it has tried to be resolved with - // Aggregate but failed. - case t: TempResolvedColumn if t.hasTried => - resolve(t.nameParts).getOrElse(t) + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) { + if (e.resolved || !e.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) return e + val resolved = e match { + case u @ UnresolvedAttribute(nameParts) => + val result = withPosition(u) { + resolve(nameParts).getOrElse(u) match { + // We trim unnecessary alias here. 
Note that, we cannot trim the alias at top-level, + case Alias(child, _) if !isTopLevel => child + case other => other + } + } + result + + // Re-resolves `TempResolvedColumn` as variable references if it has tried to be + // resolved with Aggregate but failed. + case t: TempResolvedColumn if t.hasTried => withPosition(t) { + resolve(t.nameParts).getOrElse(t) match { + case _: UnresolvedAttribute => t + case other => other + } + } + + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + resolved.copyTagsFrom(e) + resolved } + + innerResolve(e, isTopLevel = true) } // Resolves `UnresolvedAttribute` to `TempResolvedColumn` via `plan.child.output` if plan is an diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out index e75c7946ef76..6a6ffe85ad59 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out @@ -74,7 +74,7 @@ CreateVariable defaultvalueexpression(null, null), false -- !query SELECT 'Expect: INT, NULL', typeof(var1), var1 -- !query analysis -Project [Expect: INT, NULL AS Expect: INT, NULL#x, typeof(variablereference(system.session.var1=CAST(NULL AS INT))) AS typeof(variablereference(system.session.var1=CAST(NULL AS INT)) AS var1)#x, variablereference(system.session.var1=CAST(NULL AS INT)) AS var1#x] +Project [Expect: INT, NULL AS Expect: INT, NULL#x, typeof(variablereference(system.session.var1=CAST(NULL AS INT))) AS typeof(variablereference(system.session.var1=CAST(NULL AS INT)))#x, variablereference(system.session.var1=CAST(NULL AS INT)) AS var1#x] +- OneRowRelation @@ -88,7 +88,7 @@ CreateVariable defaultvalueexpression(null, null), true -- !query SELECT 'Expect: DOUBLE, NULL', typeof(var1), var1 -- !query analysis -Project [Expect: DOUBLE, NULL AS Expect: DOUBLE, NULL#x, 
typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE))) AS typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1)#x, variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1#x] +Project [Expect: DOUBLE, NULL AS Expect: DOUBLE, NULL#x, typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE))) AS typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE)))#x, variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1#x] +- OneRowRelation @@ -108,7 +108,7 @@ DECLARE OR REPLACE VARIABLE var1 TIMESTAMP -- !query SELECT 'Expect: TIMESTAMP, NULL', typeof(var1), var1 -- !query analysis -Project [Expect: TIMESTAMP, NULL AS Expect: TIMESTAMP, NULL#x, typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP))) AS typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS var1)#x, variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS var1#x] +Project [Expect: TIMESTAMP, NULL AS Expect: TIMESTAMP, NULL#x, typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP))) AS typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)))#x, variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS var1#x] +- OneRowRelation @@ -1969,21 +1969,21 @@ Project [variablereference(system.session.var1=1) AS var1#x] -- !query SELECT sum(var1) FROM VALUES(1) -- !query analysis -Aggregate [sum(variablereference(system.session.var1=1)) AS sum(variablereference(system.session.var1=1) AS var1)#xL] +Aggregate [sum(variablereference(system.session.var1=1)) AS sum(variablereference(system.session.var1=1))#xL] +- LocalRelation [col1#x] -- !query SELECT var1 + SUM(0) FROM VALUES(1) -- !query analysis -Aggregate [(cast(variablereference(system.session.var1=1) as bigint) + sum(0)) AS (variablereference(system.session.var1=1) AS var1 + sum(0))#xL] +Aggregate [(cast(variablereference(system.session.var1=1) as bigint) + sum(0)) AS (variablereference(system.session.var1=1) + 
sum(0))#xL] +- LocalRelation [col1#x] -- !query SELECT substr('12345', var1, 1) -- !query analysis -Project [substr(12345, variablereference(system.session.var1=1), 1) AS substr(12345, variablereference(system.session.var1=1) AS var1, 1)#x] +Project [substr(12345, variablereference(system.session.var1=1), 1) AS substr(12345, variablereference(system.session.var1=1), 1)#x] +- OneRowRelation @@ -2031,7 +2031,7 @@ org.apache.spark.sql.catalyst.parser.ParseException -- !query SELECT array(1, 2, 4)[var1] -- !query analysis -Project [array(1, 2, 4)[variablereference(system.session.var1=1)] AS array(1, 2, 4)[variablereference(system.session.var1=1) AS var1]#x] +Project [array(1, 2, 4)[variablereference(system.session.var1=1)] AS array(1, 2, 4)[variablereference(system.session.var1=1)]#x] +- OneRowRelation @@ -2126,6 +2126,35 @@ DROP VIEW IF EXISTS V DropTableCommand `spark_catalog`.`default`.`V`, true, true, false +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET VARIABLE title = 'variable references -- test constant folding' +-- !query analysis +SetVariable [variablereference(system.session.title='variable references -- prohibited')] ++- Project [variable references -- test constant folding AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'a INT' +-- !query analysis +CreateVariable defaultvalueexpression(cast(a INT as string), 'a INT'), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT from_json('{"a": 1}', var1) +-- !query analysis +Project [from_json(StructField(a,IntegerType,true), {"a": 1}, Some(America/Los_Angeles)) AS from_json({"a": 1})#x] ++- OneRowRelation + + -- !query DROP TEMPORARY VARIABLE var1 -- !query analysis diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql index 53149a5e37b2..2dd205adfa04 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql @@ -375,3 +375,9 @@ CREATE OR REPLACE VIEW v AS SELECT var1 AS c1; DROP VIEW IF EXISTS V; DROP TEMPORARY VARIABLE var1; + +SET VARIABLE title = 'variable references -- test constant folding'; + +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'a INT'; +SELECT from_json('{"a": 1}', var1); +DROP TEMPORARY VARIABLE var1; diff --git a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out index 7a5aa87d683a..67f867e25741 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out @@ -82,7 +82,7 @@ struct<> -- !query SELECT 'Expect: INT, NULL', typeof(var1), var1 -- !query schema -struct<Expect: INT, NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS INT)) AS var1):string,var1:int> +struct<Expect: INT, NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS INT))):string,var1:int> -- !query output Expect: INT, NULL int NULL @@ -98,7 +98,7 @@ struct<> -- !query SELECT 'Expect: DOUBLE, NULL', typeof(var1), var1 -- !query schema -struct<Expect: DOUBLE, NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1):string,var1:double> +struct<Expect: DOUBLE, NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE))):string,var1:double> -- !query output Expect: DOUBLE, NULL double NULL @@ -122,7 +122,7 @@ struct<> -- !query SELECT 'Expect: TIMESTAMP, NULL', typeof(var1), var1 -- !query schema -struct<Expect: TIMESTAMP, 
NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS var1):string,var1:timestamp> +struct<Expect: TIMESTAMP, NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP))):string,var1:timestamp> -- !query output Expect: TIMESTAMP, NULL timestamp NULL @@ -2138,7 +2138,7 @@ struct<var1:int> -- !query SELECT sum(var1) FROM VALUES(1) -- !query schema -struct<sum(variablereference(system.session.var1=1) AS var1):bigint> +struct<sum(variablereference(system.session.var1=1)):bigint> -- !query output 1 @@ -2146,7 +2146,7 @@ struct<sum(variablereference(system.session.var1=1) AS var1):bigint> -- !query SELECT var1 + SUM(0) FROM VALUES(1) -- !query schema -struct<(variablereference(system.session.var1=1) AS var1 + sum(0)):bigint> +struct<(variablereference(system.session.var1=1) + sum(0)):bigint> -- !query output 1 @@ -2154,7 +2154,7 @@ struct<(variablereference(system.session.var1=1) AS var1 + sum(0)):bigint> -- !query SELECT substr('12345', var1, 1) -- !query schema -struct<substr(12345, variablereference(system.session.var1=1) AS var1, 1):string> +struct<substr(12345, variablereference(system.session.var1=1), 1):string> -- !query output 1 @@ -2202,7 +2202,7 @@ org.apache.spark.sql.catalyst.parser.ParseException -- !query SELECT array(1, 2, 4)[var1] -- !query schema -struct<array(1, 2, 4)[variablereference(system.session.var1=1) AS var1]:int> +struct<array(1, 2, 4)[variablereference(system.session.var1=1)]:int> -- !query output 2 @@ -2303,3 +2303,35 @@ DROP TEMPORARY VARIABLE var1 struct<> -- !query output + + +-- !query +SET VARIABLE title = 'variable references -- test constant folding' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'a INT' +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT from_json('{"a": 1}', var1) +-- !query schema +struct<from_json({"a": 1}):struct<a:int>> +-- !query output +{"a":1} + + +-- !query +DROP TEMPORARY 
VARIABLE var1 +-- !query schema +struct<> +-- !query output + --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org