dtenedor commented on code in PR #52334:
URL: https://github.com/apache/spark/pull/52334#discussion_r2422359131


##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala:
##########
@@ -57,6 +68,51 @@ case class Origin(
  */
 trait WithOrigin {
   def origin: Origin
+
+  /**
+   * Update query contexts in this object with translated positions. Uses 
reflection to
+   * generically update any object with query context.
+   */
+  def updateQueryContext(translator: Array[QueryContext] => 
Array[QueryContext]): WithOrigin = {
+    try {
+      val thisClass = this.getClass
+
+      // Try to find query context using common method names
+      val contextMethodNames = Seq("getQueryContext", "context")
+
+      for (methodName <- contextMethodNames) {
+        try {
+          val getMethod = thisClass.getMethod(methodName)
+          val currentContexts = 
getMethod.invoke(this).asInstanceOf[Array[QueryContext]]
+          val translatedContexts = translator(currentContexts)
+
+          // Try to update the field in-place
+          val fieldName = if (methodName == "getQueryContext") "queryContext" 
else "context"
+          try {
+            val field = thisClass.getDeclaredField(fieldName)
+            field.setAccessible(true)
+            field.set(this, translatedContexts)
+            return this // Successfully updated in-place!

Review Comment:
   Can you skip this early return? Returning `this` on L109 instead would work equivalently.



##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala:
##########
@@ -57,6 +68,51 @@ case class Origin(
  */
 trait WithOrigin {
   def origin: Origin
+
+  /**
+   * Update query contexts in this object with translated positions. Uses 
reflection to
+   * generically update any object with query context.
+   */
+  def updateQueryContext(translator: Array[QueryContext] => 
Array[QueryContext]): WithOrigin = {
+    try {
+      val thisClass = this.getClass
+
+      // Try to find query context using common method names
+      val contextMethodNames = Seq("getQueryContext", "context")
+
+      for (methodName <- contextMethodNames) {
+        try {
+          val getMethod = thisClass.getMethod(methodName)
+          val currentContexts = 
getMethod.invoke(this).asInstanceOf[Array[QueryContext]]
+          val translatedContexts = translator(currentContexts)
+
+          // Try to update the field in-place
+          val fieldName = if (methodName == "getQueryContext") "queryContext" 
else "context"
+          try {
+            val field = thisClass.getDeclaredField(fieldName)
+            field.setAccessible(true)
+            field.set(this, translatedContexts)
+            return this // Successfully updated in-place!
+          } catch {
+            case _: NoSuchFieldException | _: IllegalAccessException |
+                _: IllegalArgumentException =>
+            // Field update failed, continue to try other methods
+          }
+        } catch {
+          case _: NoSuchMethodException | _: IllegalAccessException |

Review Comment:
   Would it be simpler, and equivalent, to combine all of these exceptions into a
single try/catch block and avoid nesting three of them?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParameterHandler.scala:
##########
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import scala.util.{Failure, Success, Try}
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.trees.ParameterSubstitutionInfo
+import org.apache.spark.sql.catalyst.util.LiteralToSqlConverter
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+/**
+ * Handler for parameter substitution across different Spark SQL contexts.
+ *
+ * This class consolidates the common parameter handling logic used by 
SparkSqlParser,
+ * SparkConnectPlanner, and ExecuteImmediate. It provides a single, consistent 
API
+ * for all parameter substitution operations in Spark SQL.
+ *
+ * Key features:
+ * - Automatic parameter type detection (named vs positional)
+ * - Uses CompoundOrSingleStatement parsing for all SQL constructs
+ * - Consistent error handling and validation
+ * - Support for complex data types (arrays, maps, nested structures)
+ * - Thread-safe operations with position-aware error context
+ *
+ * The handler integrates with the parser through callback mechanisms stored in
+ * CurrentOrigin to ensure error positions are correctly mapped back to the 
original SQL text.
+ *
+ * @example Basic usage:
+ * {{{
+ * val handler = new ParameterHandler()
+ * val context = NamedParameterContext(Map("param1" -> Literal(42)))
+ * val result = handler.substituteParameters("SELECT :param1", context)
+ * // result: "SELECT 42"
+ * }}}
+ *
+ * @example Optional context:
+ * {{{
+ * val handler = new ParameterHandler()
+ * val context = Some(NamedParameterContext(Map("param1" -> Literal(42))))
+ * val result = handler.substituteParametersIfNeeded("SELECT :param1", context)
+ * // result: "SELECT 42"
+ * }}}
+ *
+ * @see [[SubstituteParamsParser]] for the underlying parameter substitution 
logic
+ */
+class ParameterHandler {
+
+  // Compiled regex pattern for efficient parameter marker detection.
+  private val parameterMarkerPattern = java.util.regex.Pattern.compile("[?:]")
+
+

Review Comment:
   ```suggestion
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala:
##########
@@ -49,10 +53,195 @@ class SparkSqlParser extends AbstractSqlParser {
   val astBuilder = new SparkSqlAstBuilder()
 
   private val substitutor = new VariableSubstitution()
+  private[execution] val parameterHandler = new ParameterHandler()
+
+  /**
+   * Parse SQL with explicit parameter context, avoiding thread-local usage.
+   * This is the preferred method for parsing SQL with parameters.
+   *
+   * @param command The SQL text to parse
+   * @param parameterContext The parameter context containing parameter values
+   * @param toResult Function to convert the parser result
+   * @return The parsed result
+   */
+  def parseWithParameters[T](
+      command: String,
+      parameterContext: ParameterContext)
+      (toResult: SqlBaseParser => T): T = {
+    parseInternal(command, Some(parameterContext), isTopLevel = true)(toResult)
+  }
+
+  /**
+   * Parse SQL plan with explicit parameter context, avoiding thread-local 
usage.
+   * This is the preferred method for parsing SQL plans with parameters.
+   *
+   * @param sqlText The SQL text to parse
+   * @param parameterContext The parameter context containing parameter values
+   * @return The parsed logical plan
+   */
+  override def parsePlanWithParameters(
+      sqlText: String,
+      parameterContext: ParameterContext): LogicalPlan = {
+    parseWithParameters(sqlText, parameterContext) { parser =>
+      val ctx = parser.compoundOrSingleStatement()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitCompoundOrSingleStatement(ctx) match {
+          case compoundBody: CompoundPlanStatement => compoundBody
+          case plan: LogicalPlan => plan
+          case _ =>
+            val position = Origin(None, None)
+            throw QueryParsingErrors.sqlStatementUnsupportedError(sqlText, 
position)
+        }
+      }
+    }
+  }
 
   protected override def parse[T](command: String)(toResult: SqlBaseParser => 
T): T = {
-    super.parse(substitutor.substitute(command))(toResult)
+    parseInternal(command, None, isTopLevel = true)(toResult)
+  }
+
+  /**
+   * Internal parse method that handles both parameter substitution and 
regular parsing.
+   *
+   * @param command The SQL text to parse
+   * @param parameterContext Optional parameter context for parameter 
substitution
+   * @param isTopLevel Whether this is a top-level parse (vs identifier/data 
type parsing)
+   * @param toResult Function to convert the parser result
+   * @return The parsed result
+   */
+  private def parseInternal[T](
+      command: String,
+      parameterContext: Option[ParameterContext],
+      isTopLevel: Boolean)
+      (toResult: SqlBaseParser => T): T = {
+
+    // Clear any stale substitution info from previous parsing operations to 
prevent contamination.
+    if (isTopLevel) {
+      CurrentOrigin.set(CurrentOrigin.get.copy(parameterSubstitutionInfo = 
None))
+    }
+
+    // Step 1: Check if parameter substitution should occur.
+    val (paramSubstituted, substitutionOccurred, hasParameters) = 
parameterContext match {
+      case Some(context) if isTopLevel =>
+        if (SQLConf.get.legacyParameterSubstitutionConstantsOnly) {
+          // Legacy mode: skip parameter substitution but still detect 
parameters for context.
+          // Parameters detected but substitution skipped in legacy mode.
+          // Position mapping will be set up below.
+          (command, false, true)
+        } else {
+          // Modern mode: perform parameter substitution if parameters are 
present.
+          val substituted = substituteParametersOrSetupCallback(command, 
context)
+          (substituted, substituted != command, true) // Track if substitution 
occurred.
+        }
+      case _ =>
+        // No parameter context or not top-level - no parameter substitution.
+        (command, false, false)
+    }
+
+    // Step 2: Apply existing variable substitution.
+    val variableSubstituted = substitutor.substitute(paramSubstituted)
+
+    // Step 3: Set up origin with original SQL text for parameter-aware error 
reporting.
+    val currentOrigin = CurrentOrigin.get
+    val originToUse = if ((substitutionOccurred || hasParameters) && 
isTopLevel) {
+      // Parameter substitution occurred or parameters detected in legacy mode.
+      // Set original SQL text for accurate error position mapping.
+      // IMPORTANT: Preserve any substitution info set by parameter 
substitution.
+      val baseOrigin = currentOrigin.copy(
+        sqlText = Some(command), // Use original SQL text, not substituted.
+        startIndex = Some(0),
+        stopIndex = Some(command.length - 1)
+        // parameterSubstitutionInfo is preserved by copy().
+      )
+
+      // Set up identity substitution info for legacy mode parameter error 
reporting.
+      if (hasParameters && !substitutionOccurred) {
+        // Legacy mode - set up identity info for proper error context.
+        val identityInfo = ParameterSubstitutionInfo(
+          originalSql = command,
+          isIdentity = true,
+          positionMapper = None
+        )
+        baseOrigin.copy(parameterSubstitutionInfo = Some(identityInfo))
+      } else {
+        baseOrigin
+      }
+    } else {
+      // No substitution or nested call - use existing origin unchanged.
+      currentOrigin
+    }
+
+    CurrentOrigin.withOrigin(originToUse) {
+      super.parse(variableSubstituted)(toResult)
+    }
+  }
+
+  private def substituteParametersOrSetupCallback(
+      command: String,
+      context: ParameterContext): String = {
+
+    // Check legacy configuration - if true, skip parameter substitution 
during parsing.
+    if (SQLConf.get.legacyParameterSubstitutionConstantsOnly) {
+      // In legacy mode, set up identity mapping to ensure clean error context.
+      parameterHandler.setupSubstitutionContext(
+        command,
+        command, // Original = substituted in legacy mode.
+        PositionMapper.identity(command), // Identity mapper since no 
substitution.
+        true // isIdentity = true.
+      )
+      command
+    } else {
+      parameterHandler.substituteParameters(command, context)
+    }
+  }
+
+  /**
+   * Internal parse method for identifiers and data types that bypasses 
parameter logic.
+   * This ensures clean parsing without parameter substitution side effects.
+   */
+  private def parseIdentifierInternal[T](command: String)(toResult: 
SqlBaseParser => T): T = {
+    // Clear any stale substitution info to prevent contamination between 
tests/operations.
+    val currentOrigin = CurrentOrigin.get
+    if (currentOrigin.parameterSubstitutionInfo.isDefined) {
+      CurrentOrigin.set(currentOrigin.copy(parameterSubstitutionInfo = None))
+    }
+
+    val variableSubstituted = substitutor.substitute(command)
+    try {
+      super.parse(variableSubstituted)(toResult)
+    } finally {
+      // Ensure we clear any parameter substitution info that might have been 
set during parsing

Review Comment:
   ```suggestion
         // Ensure we clear any parameter substitution info that might have 
been set during parsing.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to