cloud-fan commented on code in PR #52765:
URL: https://github.com/apache/spark/pull/52765#discussion_r2518568576


##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala:
##########
@@ -161,11 +201,112 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
   }
 
   /**
-   * Create a multi-part identifier.
+   * Parse a string into a multi-part identifier. Subclasses should override this method to
+   * provide proper multi-part identifier parsing with access to a full SQL parser.
+   *
+   * For example, in AstBuilder, this would parse "`catalog`.`schema`.`table`" into Seq("catalog",
+   * "schema", "table").
+   *
+   * The base implementation fails with an assertion to catch cases where multi-part identifiers
+   * are used without a proper parser implementation.
+   *
+   * @param identifier
+   *   The identifier string to parse, potentially containing dots and backticks.
+   * @return
+   *   Sequence of identifier parts.
+   */
+  protected def parseMultipartIdentifier(identifier: String): Seq[String] = {
+    throw SparkException.internalError(
+      "parseMultipartIdentifier must be overridden by subclasses. " +
+        s"Attempted to parse: $identifier")
+  }
+
+  /**
+   * Get the identifier parts from a context, handling both regular identifiers and
+   * IDENTIFIER('literal'). This method is used to support identifier-lite syntax where
+   * IDENTIFIER('string') is folded at parse time. For qualified identifiers like
+   * IDENTIFIER('`catalog`.`schema`'), this will parse the string and return multiple parts.
+   *
+   * Subclasses should override this method to provide actual parsing logic.
+   */
+  protected def getIdentifierParts(ctx: ParserRuleContext): Seq[String] = {
+    ctx match {
+      case idCtx: IdentifierContext =>
+        // identifier can be either strictIdentifier or strictNonReserved.
+        // Recursively process the strictIdentifier.
+        Option(idCtx.strictIdentifier()).map(getIdentifierParts).getOrElse(Seq(ctx.getText))
+
+      case idLitCtx: IdentifierLiteralContext =>
+        // For IDENTIFIER('literal') in strictIdentifier.
+        val literalValue = string(visitStringLit(idLitCtx.stringLit()))
+        // Parse the string to handle qualified identifiers like "`cat`.`schema`".
+        parseMultipartIdentifier(literalValue)
+
+      case idLitCtx: IdentifierLiteralWithExtraContext =>
+        // For IDENTIFIER('literal') in errorCapturingIdentifier.
+        val literalValue = string(visitStringLit(idLitCtx.stringLit()))
+        // Parse the string to handle qualified identifiers like "`cat`.`schema`".
+        parseMultipartIdentifier(literalValue)
+
+      case base: ErrorCapturingIdentifierBaseContext =>
+        // Regular identifier with errorCapturingIdentifierExtra.
+        // Need to recursively handle identifier which might itself be IDENTIFIER('literal').
+        Option(base.identifier())
+          .flatMap(id => Option(id.strictIdentifier()).map(getIdentifierParts))
+          .getOrElse(Seq(ctx.getText))
+
+      case _ =>
+        // For regular identifiers, just return the text as a single part.
+        Seq(ctx.getText)
+    }
+  }
+
+  /**
+   * Get the text of a SINGLE identifier, handling both regular identifiers and
+   * IDENTIFIER('literal'). This method REQUIRES that the identifier be unqualified (single part
+   * only). If IDENTIFIER('qualified.name') is used where a single identifier is required, this
+   * will error.
+   */
+  protected def getIdentifierText(ctx: ParserRuleContext): String = {
+    val parts = getIdentifierParts(ctx)
+    if (parts.size > 1) {
+      throw new ParseException(
+        errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS",
+        messageParameters = Map("identifier" -> toSQLId(parts), "limit" -> "1"),
+        ctx)
+    }
+    parts.head
+  }
+
+  /**
+   * Create a multi-part identifier. Handles identifier-lite with qualified identifiers like
+   * IDENTIFIER('`cat`.`schema`').table
    */
   override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
     withOrigin(ctx) {
-      ctx.parts.asScala.map(_.getText).toSeq
+      ctx.parts.asScala.flatMap { part =>
+        // Each part is an errorCapturingIdentifier, which can be either:
+        // 1. identifier errorCapturingIdentifierExtra (regular path) - labeled as
+        //    #errorCapturingIdentifierBase
+        // 2. IDENTIFIER_KW LEFT_PAREN stringLit RIGHT_PAREN errorCapturingIdentifierExtra
+        //    (identifier-lite path) - labeled as #identifierLiteralWithExtra
+        part match {
+          case idLitWithExtra: IdentifierLiteralWithExtraContext =>
+            // This is identifier-lite: IDENTIFIER('string')
+            getIdentifierParts(idLitWithExtra)
+          case base: ErrorCapturingIdentifierBaseContext =>

Review Comment:
   These two contexts are both handled by `getIdentifierParts`; shall we just do `ctx.parts.asScala.flatMap(getIdentifierParts)`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to