This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9b739d415cd5 [SPARK-49757][SQL] Support IDENTIFIER expression in SET 
CATALOG statement
9b739d415cd5 is described below

commit 9b739d415cd51c8dd3f9332bae225196bab17d48
Author: Mikhail Nikoliukin <[email protected]>
AuthorDate: Fri Sep 27 21:48:35 2024 +0800

    [SPARK-49757][SQL] Support IDENTIFIER expression in SET CATALOG statement
    
    ### What changes were proposed in this pull request?
    
    This PR adds the possibility to use `IDENTIFIER(...)` for a catalog name
    in the `SET CATALOG` statement.
    For instance `SET CATALOG IDENTIFIER('test')` now works the same as `SET 
CATALOG test`
    
    ### Why are the changes needed?
    
    1. Consistency of the API. It can be confusing for users that they can use
    IDENTIFIER in some contexts but not for catalogs.
    2. Parametrization. It allows users to write `SET CATALOG
    IDENTIFIER(:user_data)` without worrying about SQL injections.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, now `SET CATALOG IDENTIFIER(...)` works. It can be used with any 
string expressions and parametrization.
    But multipart identifiers (like `IDENTIFIER('database.table')`) are banned
    and will raise a ParseException with the error type
    `INVALID_SQL_SYNTAX.MULTI_PART_NAME`. This restriction has always been
    enforced at the grammar level, but now a user can try to bind such an
    identifier via parameters.
    
    ### How was this patch tested?
    
    Several new unit tests covering the new behavior.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, some code suggestions
    Generated-by: GitHub Copilot
    
    Closes #48228 from mikhailnik-db/SPARK-49757.
    
    Authored-by: Mikhail Nikoliukin <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |  2 +-
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |  8 ++++-
 .../spark/sql/errors/QueryParsingErrors.scala      | 14 +++++---
 .../spark/sql/execution/SparkSqlParser.scala       | 35 +++++++++++++-----
 .../analyzer-results/identifier-clause.sql.out     |  2 +-
 .../sql-tests/results/identifier-clause.sql.out    |  2 +-
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 42 ++++++++++++++++++++++
 .../spark/sql/errors/QueryParsingErrorsSuite.scala |  4 +--
 .../sql/execution/command/DDLParserSuite.scala     |  4 +--
 9 files changed, 93 insertions(+), 20 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index e83202d9e5ee..3fcb53426ecc 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3023,7 +3023,7 @@
       },
       "MULTI_PART_NAME" : {
         "message" : [
-          "<statement> with multiple part function name(<funcName>) is not 
allowed."
+          "<statement> with multiple part name(<name>) is not allowed."
         ]
       },
       "OPTION_IS_INVALID" : {
diff --git 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 094f7f5315b8..866634b04128 100644
--- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -148,7 +148,7 @@ statement
     | ctes? dmlStatementNoWith                                         
#dmlStatement
     | USE identifierReference                                          #use
     | USE namespace identifierReference                                
#useNamespace
-    | SET CATALOG (errorCapturingIdentifier | stringLit)                  
#setCatalog
+    | SET CATALOG catalogIdentifierReference                           
#setCatalog
     | CREATE namespace (IF errorCapturingNot EXISTS)? identifierReference
         (commentSpec |
          locationSpec |
@@ -594,6 +594,12 @@ identifierReference
     | multipartIdentifier
     ;
 
+catalogIdentifierReference
+    : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN
+    | errorCapturingIdentifier
+    | stringLit
+    ;
+
 queryOrganization
     : (ORDER BY order+=sortItem (COMMA order+=sortItem)*)?
       (CLUSTER BY clusterBy+=expression (COMMA clusterBy+=expression)*)?
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index b19607a28f06..b0743d6de477 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -621,9 +621,8 @@ private[sql] object QueryParsingErrors extends 
DataTypeErrorsBase {
   def unsupportedFunctionNameError(funcName: Seq[String], ctx: 
ParserRuleContext): Throwable = {
     new ParseException(
       errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
-      messageParameters = Map(
-        "statement" -> toSQLStmt("CREATE TEMPORARY FUNCTION"),
-        "funcName" -> toSQLId(funcName)),
+      messageParameters =
+        Map("statement" -> toSQLStmt("CREATE TEMPORARY FUNCTION"), "name" -> 
toSQLId(funcName)),
       ctx)
   }
 
@@ -665,7 +664,14 @@ private[sql] object QueryParsingErrors extends 
DataTypeErrorsBase {
     new ParseException(
       errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
       messageParameters =
-        Map("statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), "funcName" -> 
toSQLId(name)),
+        Map("statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), "name" -> 
toSQLId(name)),
+      ctx)
+  }
+
+  def invalidNameForSetCatalog(name: Seq[String], ctx: ParserRuleContext): 
Throwable = {
+    new ParseException(
+      errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
+      messageParameters = Map("statement" -> toSQLStmt("SET CATALOG"), "name" 
-> toSQLId(name)),
       ctx)
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index a8261e5d98ba..1c735154f25e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -27,7 +27,7 @@ import org.antlr.v4.runtime.tree.TerminalNode
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, 
PersistedView, SchemaEvolution, SchemaTypeEvolution, UnresolvedFunctionName, 
UnresolvedIdentifier, UnresolvedNamespace}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, 
PersistedView, PlanWithUnresolvedIdentifier, SchemaEvolution, 
SchemaTypeEvolution, UnresolvedFunctionName, UnresolvedIdentifier, 
UnresolvedNamespace}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 import org.apache.spark.sql.catalyst.parser._
@@ -67,6 +67,25 @@ class SparkSqlAstBuilder extends AstBuilder {
   private val configValueDef = """([^;]*);*""".r
   private val strLiteralDef = """(".*?[^\\]"|'.*?[^\\]'|[^ \n\r\t"']+)""".r
 
+  private def withCatalogIdentClause(
+      ctx: CatalogIdentifierReferenceContext,
+      builder: Seq[String] => LogicalPlan): LogicalPlan = {
+    val exprCtx = ctx.expression
+    if (exprCtx != null) {
+      // resolve later in analyzer
+      PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) 
}, Nil,
+        (ident, _) => builder(ident))
+    } else if (ctx.errorCapturingIdentifier() != null) {
+      // resolve immediately
+      builder.apply(Seq(ctx.errorCapturingIdentifier().getText))
+    } else if (ctx.stringLit() != null) {
+      // resolve immediately
+      builder.apply(Seq(string(visitStringLit(ctx.stringLit()))))
+    } else {
+      throw SparkException.internalError("Invalid catalog name")
+    }
+  }
+
   /**
    * Create a [[SetCommand]] logical plan.
    *
@@ -276,13 +295,13 @@ class SparkSqlAstBuilder extends AstBuilder {
    * Create a [[SetCatalogCommand]] logical command.
    */
   override def visitSetCatalog(ctx: SetCatalogContext): LogicalPlan = 
withOrigin(ctx) {
-    if (ctx.errorCapturingIdentifier() != null) {
-      SetCatalogCommand(ctx.errorCapturingIdentifier().getText)
-    } else if (ctx.stringLit() != null) {
-      SetCatalogCommand(string(visitStringLit(ctx.stringLit())))
-    } else {
-      throw SparkException.internalError("Invalid catalog name")
-    }
+    withCatalogIdentClause(ctx.catalogIdentifierReference, identifiers => {
+      if (identifiers.size > 1) {
+        // can occur when user put multipart string in IDENTIFIER(...) clause
+        throw QueryParsingErrors.invalidNameForSetCatalog(identifiers, ctx)
+      }
+      SetCatalogCommand(identifiers.head)
+    })
   }
 
   /**
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
index f0bf8b883dd8..20e6ca1e6a2e 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
@@ -893,7 +893,7 @@ org.apache.spark.sql.catalyst.parser.ParseException
   "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
   "sqlState" : "42000",
   "messageParameters" : {
-    "funcName" : "`default`.`myDoubleAvg`",
+    "name" : "`default`.`myDoubleAvg`",
     "statement" : "DROP TEMPORARY FUNCTION"
   },
   "queryContext" : [ {
diff --git 
a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out 
b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
index 952fb8fdc2bd..596745b4ba5d 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
@@ -1024,7 +1024,7 @@ org.apache.spark.sql.catalyst.parser.ParseException
   "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
   "sqlState" : "42000",
   "messageParameters" : {
-    "funcName" : "`default`.`myDoubleAvg`",
+    "name" : "`default`.`myDoubleAvg`",
     "statement" : "DROP TEMPORARY FUNCTION"
   },
   "queryContext" : [ {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index dac066bbef83..6b58d23e9260 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2887,6 +2887,48 @@ class DataSourceV2SQLSuiteV1Filter
         "config" -> "\"spark.sql.catalog.not_exist_catalog\""))
   }
 
+  test("SPARK-49757: SET CATALOG statement with IDENTIFIER should work") {
+    val catalogManager = spark.sessionState.catalogManager
+    assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME)
+
+    sql("SET CATALOG IDENTIFIER('testcat')")
+    assert(catalogManager.currentCatalog.name() == "testcat")
+
+    spark.sql("SET CATALOG IDENTIFIER(:param)", Map("param" -> "testcat2"))
+    assert(catalogManager.currentCatalog.name() == "testcat2")
+
+    checkError(
+      exception = intercept[CatalogNotFoundException] {
+        sql("SET CATALOG IDENTIFIER('not_exist_catalog')")
+      },
+      condition = "CATALOG_NOT_FOUND",
+      parameters = Map(
+        "catalogName" -> "`not_exist_catalog`",
+        "config" -> "\"spark.sql.catalog.not_exist_catalog\"")
+    )
+  }
+
+  test("SPARK-49757: SET CATALOG statement with IDENTIFIER with multipart name 
should fail") {
+    val catalogManager = spark.sessionState.catalogManager
+    assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME)
+
+    val sqlText = "SET CATALOG IDENTIFIER(:param)"
+    checkError(
+      exception = intercept[ParseException] {
+        spark.sql(sqlText, Map("param" -> "testcat.ns1"))
+      },
+      condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
+      parameters = Map(
+        "name" -> "`testcat`.`ns1`",
+        "statement" -> "SET CATALOG"
+      ),
+      context = ExpectedContext(
+        fragment = sqlText,
+        start = 0,
+        stop = 29)
+    )
+  }
+
   test("SPARK-35973: ShowCatalogs") {
     val schema = new StructType()
       .add("catalog", StringType, nullable = false)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
index da7b6e7f63c8..666f85e19c1c 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
@@ -334,7 +334,7 @@ class QueryParsingErrorsSuite extends QueryTest with 
SharedSparkSession with SQL
       sqlState = "42000",
       parameters = Map(
         "statement" -> "CREATE TEMPORARY FUNCTION",
-        "funcName" -> "`ns`.`db`.`func`"),
+        "name" -> "`ns`.`db`.`func`"),
       context = ExpectedContext(
         fragment = sqlText,
         start = 0,
@@ -367,7 +367,7 @@ class QueryParsingErrorsSuite extends QueryTest with 
SharedSparkSession with SQL
       sqlState = "42000",
       parameters = Map(
         "statement" -> "DROP TEMPORARY FUNCTION",
-        "funcName" -> "`db`.`func`"),
+        "name" -> "`db`.`func`"),
       context = ExpectedContext(
         fragment = sqlText,
         start = 0,
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
index 176eb7c29076..8b868c0e1723 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
@@ -688,7 +688,7 @@ class DDLParserSuite extends AnalysisTest with 
SharedSparkSession {
     checkError(
       exception = parseException(sql1),
       condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
-      parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" -> 
"`a`.`b`"),
+      parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" -> 
"`a`.`b`"),
       context = ExpectedContext(
         fragment = sql1,
         start = 0,
@@ -698,7 +698,7 @@ class DDLParserSuite extends AnalysisTest with 
SharedSparkSession {
     checkError(
       exception = parseException(sql2),
       condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
-      parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" -> 
"`a`.`b`"),
+      parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" -> 
"`a`.`b`"),
       context = ExpectedContext(
         fragment = sql2,
         start = 0,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to