This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d25525d876b [SPARK-46993][SQL] Fix constant folding for session variables
3d25525d876b is described below

commit 3d25525d876be9c5f3bd2dce917cb6bec10fb3e9
Author: Serge Rielau <se...@rielau.com>
AuthorDate: Thu Feb 8 16:24:03 2024 +0800

    [SPARK-46993][SQL] Fix constant folding for session variables
    
    ### What changes were proposed in this pull request?
    
    Remove the unconditional Alias node generation when resolving a variable reference.
    
    ### Why are the changes needed?
    
    An Alias that is not at the top level of an expression blocks constant folding.
    Constant folding in turn is a requirement for variables to be usable as an argument to numerous functions, such as from_json().
    It also has performance implications.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing regression tests in sql-session-variables.sql; also added a test to validate the fix.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45059 from srielau/SPARK-46993-Fix-constant-folding-for-session-variables.
    
    Authored-by: Serge Rielau <se...@rielau.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/analysis/ColumnResolutionHelper.scala | 35 ++++++++++++----
 .../analyzer-results/sql-session-variables.sql.out | 43 ++++++++++++++++----
 .../sql-tests/inputs/sql-session-variables.sql     |  6 +++
 .../results/sql-session-variables.sql.out          | 46 ++++++++++++++++++----
 4 files changed, 109 insertions(+), 21 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 2472705d2f54..8ea50e2ceb65 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -312,14 +312,35 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
       }.map(e => Alias(e, nameParts.last)())
     }
 
-    e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, 
TEMP_RESOLVED_COLUMN)) {
-      case u: UnresolvedAttribute =>
-        resolve(u.nameParts).getOrElse(u)
-      // Re-resolves `TempResolvedColumn` as variable references if it has 
tried to be resolved with
-      // Aggregate but failed.
-      case t: TempResolvedColumn if t.hasTried =>
-        resolve(t.nameParts).getOrElse(t)
+    def innerResolve(e: Expression, isTopLevel: Boolean): Expression = 
withOrigin(e.origin) {
+      if (e.resolved || !e.containsAnyPattern(UNRESOLVED_ATTRIBUTE, 
TEMP_RESOLVED_COLUMN)) return e
+      val resolved = e match {
+        case u @ UnresolvedAttribute(nameParts) =>
+          val result = withPosition(u) {
+            resolve(nameParts).getOrElse(u) match {
+              // We trim the unnecessary alias here. Note that the top-level alias must be kept.
+              case Alias(child, _) if !isTopLevel => child
+              case other => other
+            }
+          }
+          result
+
+        // Re-resolves `TempResolvedColumn` as variable references if it has 
tried to be
+        // resolved with Aggregate but failed.
+        case t: TempResolvedColumn if t.hasTried => withPosition(t) {
+          resolve(t.nameParts).getOrElse(t) match {
+            case _: UnresolvedAttribute => t
+            case other => other
+          }
+        }
+
+        case _ => e.mapChildren(innerResolve(_, isTopLevel = false))
+      }
+      resolved.copyTagsFrom(e)
+      resolved
     }
+
+    innerResolve(e, isTopLevel = true)
   }
 
   // Resolves `UnresolvedAttribute` to `TempResolvedColumn` via 
`plan.child.output` if plan is an
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out
index e75c7946ef76..6a6ffe85ad59 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out
@@ -74,7 +74,7 @@ CreateVariable defaultvalueexpression(null, null), false
 -- !query
 SELECT 'Expect: INT, NULL', typeof(var1), var1
 -- !query analysis
-Project [Expect: INT, NULL AS Expect: INT, NULL#x, 
typeof(variablereference(system.session.var1=CAST(NULL AS INT))) AS 
typeof(variablereference(system.session.var1=CAST(NULL AS INT)) AS var1)#x, 
variablereference(system.session.var1=CAST(NULL AS INT)) AS var1#x]
+Project [Expect: INT, NULL AS Expect: INT, NULL#x, 
typeof(variablereference(system.session.var1=CAST(NULL AS INT))) AS 
typeof(variablereference(system.session.var1=CAST(NULL AS INT)))#x, 
variablereference(system.session.var1=CAST(NULL AS INT)) AS var1#x]
 +- OneRowRelation
 
 
@@ -88,7 +88,7 @@ CreateVariable defaultvalueexpression(null, null), true
 -- !query
 SELECT 'Expect: DOUBLE, NULL', typeof(var1), var1
 -- !query analysis
-Project [Expect: DOUBLE, NULL AS Expect: DOUBLE, NULL#x, 
typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE))) AS 
typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1)#x, 
variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1#x]
+Project [Expect: DOUBLE, NULL AS Expect: DOUBLE, NULL#x, 
typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE))) AS 
typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE)))#x, 
variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1#x]
 +- OneRowRelation
 
 
@@ -108,7 +108,7 @@ DECLARE OR REPLACE VARIABLE var1 TIMESTAMP
 -- !query
 SELECT 'Expect: TIMESTAMP, NULL', typeof(var1), var1
 -- !query analysis
-Project [Expect: TIMESTAMP, NULL AS Expect: TIMESTAMP, NULL#x, 
typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP))) AS 
typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS 
var1)#x, variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS 
var1#x]
+Project [Expect: TIMESTAMP, NULL AS Expect: TIMESTAMP, NULL#x, 
typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP))) AS 
typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)))#x, 
variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS var1#x]
 +- OneRowRelation
 
 
@@ -1969,21 +1969,21 @@ Project [variablereference(system.session.var1=1) AS 
var1#x]
 -- !query
 SELECT sum(var1) FROM VALUES(1)
 -- !query analysis
-Aggregate [sum(variablereference(system.session.var1=1)) AS 
sum(variablereference(system.session.var1=1) AS var1)#xL]
+Aggregate [sum(variablereference(system.session.var1=1)) AS 
sum(variablereference(system.session.var1=1))#xL]
 +- LocalRelation [col1#x]
 
 
 -- !query
 SELECT var1 + SUM(0) FROM VALUES(1)
 -- !query analysis
-Aggregate [(cast(variablereference(system.session.var1=1) as bigint) + sum(0)) 
AS (variablereference(system.session.var1=1) AS var1 + sum(0))#xL]
+Aggregate [(cast(variablereference(system.session.var1=1) as bigint) + sum(0)) 
AS (variablereference(system.session.var1=1) + sum(0))#xL]
 +- LocalRelation [col1#x]
 
 
 -- !query
 SELECT substr('12345', var1, 1)
 -- !query analysis
-Project [substr(12345, variablereference(system.session.var1=1), 1) AS 
substr(12345, variablereference(system.session.var1=1) AS var1, 1)#x]
+Project [substr(12345, variablereference(system.session.var1=1), 1) AS 
substr(12345, variablereference(system.session.var1=1), 1)#x]
 +- OneRowRelation
 
 
@@ -2031,7 +2031,7 @@ org.apache.spark.sql.catalyst.parser.ParseException
 -- !query
 SELECT array(1, 2, 4)[var1]
 -- !query analysis
-Project [array(1, 2, 4)[variablereference(system.session.var1=1)] AS array(1, 
2, 4)[variablereference(system.session.var1=1) AS var1]#x]
+Project [array(1, 2, 4)[variablereference(system.session.var1=1)] AS array(1, 
2, 4)[variablereference(system.session.var1=1)]#x]
 +- OneRowRelation
 
 
@@ -2126,6 +2126,35 @@ DROP VIEW IF EXISTS V
 DropTableCommand `spark_catalog`.`default`.`V`, true, true, false
 
 
+-- !query
+DROP TEMPORARY VARIABLE var1
+-- !query analysis
+DropVariable false
++- ResolvedIdentifier 
org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1
+
+
+-- !query
+SET VARIABLE title = 'variable references -- test constant folding'
+-- !query analysis
+SetVariable [variablereference(system.session.title='variable references -- 
prohibited')]
++- Project [variable references -- test constant folding AS title#x]
+   +- OneRowRelation
+
+
+-- !query
+DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'a INT'
+-- !query analysis
+CreateVariable defaultvalueexpression(cast(a INT as string), 'a INT'), true
++- ResolvedIdentifier 
org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1
+
+
+-- !query
+SELECT from_json('{"a": 1}', var1)
+-- !query analysis
+Project [from_json(StructField(a,IntegerType,true), {"a": 1}, 
Some(America/Los_Angeles)) AS from_json({"a": 1})#x]
++- OneRowRelation
+
+
 -- !query
 DROP TEMPORARY VARIABLE var1
 -- !query analysis
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql 
b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql
index 53149a5e37b2..2dd205adfa04 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql
@@ -375,3 +375,9 @@ CREATE OR REPLACE VIEW v AS SELECT var1 AS c1;
 DROP VIEW IF EXISTS V;
 
 DROP TEMPORARY VARIABLE var1;
+
+SET VARIABLE title = 'variable references -- test constant folding';
+
+DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'a INT';
+SELECT from_json('{"a": 1}', var1);
+DROP TEMPORARY VARIABLE var1;
diff --git 
a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out 
b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out
index 7a5aa87d683a..67f867e25741 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out
@@ -82,7 +82,7 @@ struct<>
 -- !query
 SELECT 'Expect: INT, NULL', typeof(var1), var1
 -- !query schema
-struct<Expect: INT, 
NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS INT)) AS 
var1):string,var1:int>
+struct<Expect: INT, 
NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS 
INT))):string,var1:int>
 -- !query output
 Expect: INT, NULL      int     NULL
 
@@ -98,7 +98,7 @@ struct<>
 -- !query
 SELECT 'Expect: DOUBLE, NULL', typeof(var1), var1
 -- !query schema
-struct<Expect: DOUBLE, 
NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE)) 
AS var1):string,var1:double>
+struct<Expect: DOUBLE, 
NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS 
DOUBLE))):string,var1:double>
 -- !query output
 Expect: DOUBLE, NULL   double  NULL
 
@@ -122,7 +122,7 @@ struct<>
 -- !query
 SELECT 'Expect: TIMESTAMP, NULL', typeof(var1), var1
 -- !query schema
-struct<Expect: TIMESTAMP, 
NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS 
TIMESTAMP)) AS var1):string,var1:timestamp>
+struct<Expect: TIMESTAMP, 
NULL:string,typeof(variablereference(system.session.var1=CAST(NULL AS 
TIMESTAMP))):string,var1:timestamp>
 -- !query output
 Expect: TIMESTAMP, NULL        timestamp       NULL
 
@@ -2138,7 +2138,7 @@ struct<var1:int>
 -- !query
 SELECT sum(var1) FROM VALUES(1)
 -- !query schema
-struct<sum(variablereference(system.session.var1=1) AS var1):bigint>
+struct<sum(variablereference(system.session.var1=1)):bigint>
 -- !query output
 1
 
@@ -2146,7 +2146,7 @@ struct<sum(variablereference(system.session.var1=1) AS 
var1):bigint>
 -- !query
 SELECT var1 + SUM(0) FROM VALUES(1)
 -- !query schema
-struct<(variablereference(system.session.var1=1) AS var1 + sum(0)):bigint>
+struct<(variablereference(system.session.var1=1) + sum(0)):bigint>
 -- !query output
 1
 
@@ -2154,7 +2154,7 @@ struct<(variablereference(system.session.var1=1) AS var1 
+ sum(0)):bigint>
 -- !query
 SELECT substr('12345', var1, 1)
 -- !query schema
-struct<substr(12345, variablereference(system.session.var1=1) AS var1, 
1):string>
+struct<substr(12345, variablereference(system.session.var1=1), 1):string>
 -- !query output
 1
 
@@ -2202,7 +2202,7 @@ org.apache.spark.sql.catalyst.parser.ParseException
 -- !query
 SELECT array(1, 2, 4)[var1]
 -- !query schema
-struct<array(1, 2, 4)[variablereference(system.session.var1=1) AS var1]:int>
+struct<array(1, 2, 4)[variablereference(system.session.var1=1)]:int>
 -- !query output
 2
 
@@ -2303,3 +2303,35 @@ DROP TEMPORARY VARIABLE var1
 struct<>
 -- !query output
 
+
+
+-- !query
+SET VARIABLE title = 'variable references -- test constant folding'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'a INT'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT from_json('{"a": 1}', var1)
+-- !query schema
+struct<from_json({"a": 1}):struct<a:int>>
+-- !query output
+{"a":1}
+
+
+-- !query
+DROP TEMPORARY VARIABLE var1
+-- !query schema
+struct<>
+-- !query output
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to