This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 91c01dbad451 [SPARK-50196][CONNECT] Fix Python error context to use a 
proper context
91c01dbad451 is described below

commit 91c01dbad451254e0e3f5efc3e33b7f5d612b5fa
Author: Takuya Ueshin <[email protected]>
AuthorDate: Sun Nov 3 14:14:02 2024 -0800

    [SPARK-50196][CONNECT] Fix Python error context to use a proper context
    
    ### What changes were proposed in this pull request?
    
    Fixes the Python error context in Spark Connect so that the DataFrame query context points to the operation that actually failed, as in Spark Classic.
    
    ### Why are the changes needed?
    
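    The Python error context reported by Spark Connect differs from the one reported by
    Spark Classic. In the example below, Spark Connect attributes the division by zero to
    `__mul__` instead of `__truediv__`.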
    
    ```py
    spark.conf.set("spark.sql.ansi.enabled", True)
    
    try:
      df = spark.range(10)
      df.withColumn("div_zero", (df.id / 0) * 10).collect()
    except Exception as ee:
      e = ee
    
    e.getQueryContext()[0].fragment()
    ```
    
    - classic
    
    ```
    >>> e.getQueryContext()[0].fragment()
    '__truediv__'
    ```
    
    - connect
    
    ```
    >>> e.getQueryContext()[0].fragment()
    '__mul__'
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. The error context reported by Spark Connect now matches Spark Classic.
    
    ### How was this patch tested?
    
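    Added related tests (`test_query_context_complex` and `test_dataframe_query_context_col` in `test_dataframe_query_context.py`).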
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48730 from ueshin/issues/SPARK-50196/dataframe_query_context.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/errors/utils.py                     |  6 ++--
 python/pyspark/sql/connect/functions/builtin.py    |  2 ++
 python/pyspark/sql/functions/builtin.py            |  2 ++
 .../sql/tests/test_dataframe_query_context.py      | 37 ++++++++++++++++++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 18 +++++------
 5 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py
index 1b7dbbb195ed..cbe5739204ac 100644
--- a/python/pyspark/errors/utils.py
+++ b/python/pyspark/errors/utils.py
@@ -33,6 +33,7 @@ from typing import (
     Union,
     TYPE_CHECKING,
     overload,
+    cast,
 )
 import pyspark
 from pyspark.errors.error_classes import ERROR_CLASSES_MAP
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
     from pyspark.sql import SparkSession
 
 T = TypeVar("T")
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
 
 _current_origin = threading.local()
 
@@ -225,7 +227,7 @@ def _capture_call_site(spark_session: "SparkSession", 
depth: int) -> str:
     return call_sites_str
 
 
-def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
+def _with_origin(func: FuncT) -> FuncT:
     """
     A decorator to capture and provide the call site information to the server 
side
     when PySpark API functions are invoked.
@@ -269,7 +271,7 @@ def _with_origin(func: Callable[..., Any]) -> Callable[..., 
Any]:
         else:
             return func(*args, **kwargs)
 
-    return wrapper
+    return cast(FuncT, wrapper)
 
 
 @overload
diff --git a/python/pyspark/sql/connect/functions/builtin.py 
b/python/pyspark/sql/connect/functions/builtin.py
index b8bd0e9bf7fd..6f3ce942eb17 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -43,6 +43,7 @@ import sys
 import numpy as np
 
 from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors.utils import _with_origin
 from pyspark.sql.dataframe import DataFrame as ParentDataFrame
 from pyspark.sql import Column
 from pyspark.sql.connect.expressions import (
@@ -238,6 +239,7 @@ def _options_to_col(options: Mapping[str, Any]) -> Column:
 # Normal Functions
 
 
+@_with_origin
 def col(col: str) -> Column:
     from pyspark.sql.connect.column import Column as ConnectColumn
 
diff --git a/python/pyspark/sql/functions/builtin.py 
b/python/pyspark/sql/functions/builtin.py
index d6662421a79e..1e5349fb1649 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -40,6 +40,7 @@ from typing import (
 )
 
 from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors.utils import _with_origin
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame as ParentDataFrame
 from pyspark.sql.types import (
@@ -293,6 +294,7 @@ def lit(col: Any) -> Column:
 
 
 @_try_remote_functions
+@_with_origin
 def col(col: str) -> Column:
     """
     Returns a :class:`~pyspark.sql.Column` based on the given column name.
diff --git a/python/pyspark/sql/tests/test_dataframe_query_context.py 
b/python/pyspark/sql/tests/test_dataframe_query_context.py
index bf0cc021ca77..edd769680c77 100644
--- a/python/pyspark/sql/tests/test_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/test_dataframe_query_context.py
@@ -22,6 +22,7 @@ from pyspark.errors import (
     QueryContextType,
     NumberFormatException,
 )
+from pyspark.sql import functions as sf
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
 )
@@ -449,6 +450,42 @@ class DataFrameQueryContextTestsMixin:
                 query_context_type=None,
             )
 
+    def test_query_context_complex(self):
+        with self.sql_conf({"spark.sql.ansi.enabled": True}):
+            # SQLQueryContext
+            with self.assertRaises(ArithmeticException) as pe:
+                self.spark.sql("select (10/0)*100").collect()
+            self.check_error(
+                exception=pe.exception,
+                errorClass="DIVIDE_BY_ZERO",
+                messageParameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.SQL,
+            )
+
+            # DataFrameQueryContext
+            df = self.spark.range(10)
+            with self.assertRaises(ArithmeticException) as pe:
+                df.withColumn("div_zero", (df.id / 0) * 10).collect()
+            self.check_error(
+                exception=pe.exception,
+                errorClass="DIVIDE_BY_ZERO",
+                messageParameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                fragment="__truediv__",
+            )
+
+    def test_dataframe_query_context_col(self):
+        with self.assertRaises(AnalysisException) as pe:
+            self.spark.range(1).select(sf.col("id") + sf.col("idd")).show()
+
+        self.check_error(
+            exception=pe.exception,
+            errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+            messageParameters={"objectName": "`idd`", "proposal": "`id`"},
+            query_context_type=QueryContextType.DataFrame,
+            fragment="col",
+        )
+
 
 class DataFrameQueryContextTests(DataFrameQueryContextTestsMixin, 
ReusedSQLTestCase):
     pass
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a9d2bd482150..979fd83612e7 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -55,8 +55,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, 
Inner, JoinType, L
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, 
CoGroup, CollectMetrics, CommandResult, Deduplicate, 
DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, 
FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, 
LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, 
MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, 
TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateSt [...]
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
-import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, 
CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, 
StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1510,14 +1509,13 @@ class SparkConnectPlanner(
   def transformExpression(
       exp: proto.Expression,
       baseRelationOpt: Option[LogicalPlan]): Expression = if (exp.hasCommon) {
-    try {
-      val origin = exp.getCommon.getOrigin
-      PySparkCurrentOrigin.set(
-        origin.getPythonOrigin.getFragment,
-        origin.getPythonOrigin.getCallSite)
-      withOrigin { doTransformExpression(exp, baseRelationOpt) }
-    } finally {
-      PySparkCurrentOrigin.clear()
+    CurrentOrigin.withOrigin {
+      val pythonOrigin = exp.getCommon.getOrigin.getPythonOrigin
+      val pysparkErrorContext = (pythonOrigin.getFragment, 
pythonOrigin.getCallSite)
+      val newOrigin = CurrentOrigin.get.copy(pysparkErrorContext = 
Some(pysparkErrorContext))
+      CurrentOrigin.withOrigin(newOrigin) {
+        doTransformExpression(exp, baseRelationOpt)
+      }
     }
   } else {
     doTransformExpression(exp, baseRelationOpt)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to