This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9b739d415cd5 [SPARK-49757][SQL] Support IDENTIFIER expression in SET
CATALOG statement
9b739d415cd5 is described below
commit 9b739d415cd51c8dd3f9332bae225196bab17d48
Author: Mikhail Nikoliukin <[email protected]>
AuthorDate: Fri Sep 27 21:48:35 2024 +0800
[SPARK-49757][SQL] Support IDENTIFIER expression in SET CATALOG statement
### What changes were proposed in this pull request?
This pr adds possibility to use `IDENTIFIER(...)` for a catalog name in
`SET CATALOG` statement.
For instance `SET CATALOG IDENTIFIER('test')` now works the same as `SET
CATALOG test`
### Why are the changes needed?
1. Consistency of API. It can be confusing for a user that they can use
IDENTIFIER in some contexts but not for catalogs.
2. Parametrization. It allows the user to write `SET CATALOG
IDENTIFIER(:user_data)` without worrying about SQL injections.
### Does this PR introduce _any_ user-facing change?
Yes, now `SET CATALOG IDENTIFIER(...)` works. It can be used with any
string expressions and parametrization.
But multipart identifiers (like `IDENTIFIER('database.table')`) are banned
and will raise a ParseException with error class
`INVALID_SQL_SYNTAX.MULTI_PART_NAME`. This restriction has always existed
at the grammar level, but now a user can try to bind such an identifier via parameters.
### How was this patch tested?
Unit tests, including several new ones covering the new behavior.
### Was this patch authored or co-authored using generative AI tooling?
Yes, some code suggestions
Generated-by: GitHub Copilot
Closes #48228 from mikhailnik-db/SPARK-49757.
Authored-by: Mikhail Nikoliukin <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 2 +-
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 8 ++++-
.../spark/sql/errors/QueryParsingErrors.scala | 14 +++++---
.../spark/sql/execution/SparkSqlParser.scala | 35 +++++++++++++-----
.../analyzer-results/identifier-clause.sql.out | 2 +-
.../sql-tests/results/identifier-clause.sql.out | 2 +-
.../spark/sql/connector/DataSourceV2SQLSuite.scala | 42 ++++++++++++++++++++++
.../spark/sql/errors/QueryParsingErrorsSuite.scala | 4 +--
.../sql/execution/command/DDLParserSuite.scala | 4 +--
9 files changed, 93 insertions(+), 20 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index e83202d9e5ee..3fcb53426ecc 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3023,7 +3023,7 @@
},
"MULTI_PART_NAME" : {
"message" : [
- "<statement> with multiple part function name(<funcName>) is not
allowed."
+ "<statement> with multiple part name(<name>) is not allowed."
]
},
"OPTION_IS_INVALID" : {
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 094f7f5315b8..866634b04128 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -148,7 +148,7 @@ statement
| ctes? dmlStatementNoWith
#dmlStatement
| USE identifierReference #use
| USE namespace identifierReference
#useNamespace
- | SET CATALOG (errorCapturingIdentifier | stringLit)
#setCatalog
+ | SET CATALOG catalogIdentifierReference
#setCatalog
| CREATE namespace (IF errorCapturingNot EXISTS)? identifierReference
(commentSpec |
locationSpec |
@@ -594,6 +594,12 @@ identifierReference
| multipartIdentifier
;
+catalogIdentifierReference
+ : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN
+ | errorCapturingIdentifier
+ | stringLit
+ ;
+
queryOrganization
: (ORDER BY order+=sortItem (COMMA order+=sortItem)*)?
(CLUSTER BY clusterBy+=expression (COMMA clusterBy+=expression)*)?
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index b19607a28f06..b0743d6de477 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -621,9 +621,8 @@ private[sql] object QueryParsingErrors extends
DataTypeErrorsBase {
def unsupportedFunctionNameError(funcName: Seq[String], ctx:
ParserRuleContext): Throwable = {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
- messageParameters = Map(
- "statement" -> toSQLStmt("CREATE TEMPORARY FUNCTION"),
- "funcName" -> toSQLId(funcName)),
+ messageParameters =
+ Map("statement" -> toSQLStmt("CREATE TEMPORARY FUNCTION"), "name" ->
toSQLId(funcName)),
ctx)
}
@@ -665,7 +664,14 @@ private[sql] object QueryParsingErrors extends
DataTypeErrorsBase {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
messageParameters =
- Map("statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), "funcName" ->
toSQLId(name)),
+ Map("statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), "name" ->
toSQLId(name)),
+ ctx)
+ }
+
+ def invalidNameForSetCatalog(name: Seq[String], ctx: ParserRuleContext):
Throwable = {
+ new ParseException(
+ errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
+ messageParameters = Map("statement" -> toSQLStmt("SET CATALOG"), "name"
-> toSQLId(name)),
ctx)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index a8261e5d98ba..1c735154f25e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -27,7 +27,7 @@ import org.antlr.v4.runtime.tree.TerminalNode
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView,
PersistedView, SchemaEvolution, SchemaTypeEvolution, UnresolvedFunctionName,
UnresolvedIdentifier, UnresolvedNamespace}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView,
PersistedView, PlanWithUnresolvedIdentifier, SchemaEvolution,
SchemaTypeEvolution, UnresolvedFunctionName, UnresolvedIdentifier,
UnresolvedNamespace}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.parser._
@@ -67,6 +67,25 @@ class SparkSqlAstBuilder extends AstBuilder {
private val configValueDef = """([^;]*);*""".r
private val strLiteralDef = """(".*?[^\\]"|'.*?[^\\]'|[^ \n\r\t"']+)""".r
+ private def withCatalogIdentClause(
+ ctx: CatalogIdentifierReferenceContext,
+ builder: Seq[String] => LogicalPlan): LogicalPlan = {
+ val exprCtx = ctx.expression
+ if (exprCtx != null) {
+ // resolve later in analyzer
+ PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx)
}, Nil,
+ (ident, _) => builder(ident))
+ } else if (ctx.errorCapturingIdentifier() != null) {
+ // resolve immediately
+ builder.apply(Seq(ctx.errorCapturingIdentifier().getText))
+ } else if (ctx.stringLit() != null) {
+ // resolve immediately
+ builder.apply(Seq(string(visitStringLit(ctx.stringLit()))))
+ } else {
+ throw SparkException.internalError("Invalid catalog name")
+ }
+ }
+
/**
* Create a [[SetCommand]] logical plan.
*
@@ -276,13 +295,13 @@ class SparkSqlAstBuilder extends AstBuilder {
* Create a [[SetCatalogCommand]] logical command.
*/
override def visitSetCatalog(ctx: SetCatalogContext): LogicalPlan =
withOrigin(ctx) {
- if (ctx.errorCapturingIdentifier() != null) {
- SetCatalogCommand(ctx.errorCapturingIdentifier().getText)
- } else if (ctx.stringLit() != null) {
- SetCatalogCommand(string(visitStringLit(ctx.stringLit())))
- } else {
- throw SparkException.internalError("Invalid catalog name")
- }
+ withCatalogIdentClause(ctx.catalogIdentifierReference, identifiers => {
+ if (identifiers.size > 1) {
+ // can occur when user put multipart string in IDENTIFIER(...) clause
+ throw QueryParsingErrors.invalidNameForSetCatalog(identifiers, ctx)
+ }
+ SetCatalogCommand(identifiers.head)
+ })
}
/**
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
index f0bf8b883dd8..20e6ca1e6a2e 100644
---
a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
+++
b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
@@ -893,7 +893,7 @@ org.apache.spark.sql.catalyst.parser.ParseException
"errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
"sqlState" : "42000",
"messageParameters" : {
- "funcName" : "`default`.`myDoubleAvg`",
+ "name" : "`default`.`myDoubleAvg`",
"statement" : "DROP TEMPORARY FUNCTION"
},
"queryContext" : [ {
diff --git
a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
index 952fb8fdc2bd..596745b4ba5d 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
@@ -1024,7 +1024,7 @@ org.apache.spark.sql.catalyst.parser.ParseException
"errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
"sqlState" : "42000",
"messageParameters" : {
- "funcName" : "`default`.`myDoubleAvg`",
+ "name" : "`default`.`myDoubleAvg`",
"statement" : "DROP TEMPORARY FUNCTION"
},
"queryContext" : [ {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index dac066bbef83..6b58d23e9260 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2887,6 +2887,48 @@ class DataSourceV2SQLSuiteV1Filter
"config" -> "\"spark.sql.catalog.not_exist_catalog\""))
}
+ test("SPARK-49757: SET CATALOG statement with IDENTIFIER should work") {
+ val catalogManager = spark.sessionState.catalogManager
+ assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME)
+
+ sql("SET CATALOG IDENTIFIER('testcat')")
+ assert(catalogManager.currentCatalog.name() == "testcat")
+
+ spark.sql("SET CATALOG IDENTIFIER(:param)", Map("param" -> "testcat2"))
+ assert(catalogManager.currentCatalog.name() == "testcat2")
+
+ checkError(
+ exception = intercept[CatalogNotFoundException] {
+ sql("SET CATALOG IDENTIFIER('not_exist_catalog')")
+ },
+ condition = "CATALOG_NOT_FOUND",
+ parameters = Map(
+ "catalogName" -> "`not_exist_catalog`",
+ "config" -> "\"spark.sql.catalog.not_exist_catalog\"")
+ )
+ }
+
+ test("SPARK-49757: SET CATALOG statement with IDENTIFIER with multipart name
should fail") {
+ val catalogManager = spark.sessionState.catalogManager
+ assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME)
+
+ val sqlText = "SET CATALOG IDENTIFIER(:param)"
+ checkError(
+ exception = intercept[ParseException] {
+ spark.sql(sqlText, Map("param" -> "testcat.ns1"))
+ },
+ condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
+ parameters = Map(
+ "name" -> "`testcat`.`ns1`",
+ "statement" -> "SET CATALOG"
+ ),
+ context = ExpectedContext(
+ fragment = sqlText,
+ start = 0,
+ stop = 29)
+ )
+ }
+
test("SPARK-35973: ShowCatalogs") {
val schema = new StructType()
.add("catalog", StringType, nullable = false)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
index da7b6e7f63c8..666f85e19c1c 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
@@ -334,7 +334,7 @@ class QueryParsingErrorsSuite extends QueryTest with
SharedSparkSession with SQL
sqlState = "42000",
parameters = Map(
"statement" -> "CREATE TEMPORARY FUNCTION",
- "funcName" -> "`ns`.`db`.`func`"),
+ "name" -> "`ns`.`db`.`func`"),
context = ExpectedContext(
fragment = sqlText,
start = 0,
@@ -367,7 +367,7 @@ class QueryParsingErrorsSuite extends QueryTest with
SharedSparkSession with SQL
sqlState = "42000",
parameters = Map(
"statement" -> "DROP TEMPORARY FUNCTION",
- "funcName" -> "`db`.`func`"),
+ "name" -> "`db`.`func`"),
context = ExpectedContext(
fragment = sqlText,
start = 0,
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
index 176eb7c29076..8b868c0e1723 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
@@ -688,7 +688,7 @@ class DDLParserSuite extends AnalysisTest with
SharedSparkSession {
checkError(
exception = parseException(sql1),
condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
- parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" ->
"`a`.`b`"),
+ parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" ->
"`a`.`b`"),
context = ExpectedContext(
fragment = sql1,
start = 0,
@@ -698,7 +698,7 @@ class DDLParserSuite extends AnalysisTest with
SharedSparkSession {
checkError(
exception = parseException(sql2),
condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
- parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" ->
"`a`.`b`"),
+ parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" ->
"`a`.`b`"),
context = ExpectedContext(
fragment = sql2,
start = 0,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]