This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 91c01dbad451 [SPARK-50196][CONNECT] Fix Python error context to use a
proper context
91c01dbad451 is described below
commit 91c01dbad451254e0e3f5efc3e33b7f5d612b5fa
Author: Takuya Ueshin <[email protected]>
AuthorDate: Sun Nov 3 14:14:02 2024 -0800
[SPARK-50196][CONNECT] Fix Python error context to use a proper context
### What changes were proposed in this pull request?
Fixes Python error context in Spark Connect to use a proper context.
### Why are the changes needed?
The Python error context in Spark Connect has a different context than
Spark Classic.
```py
spark.conf.set("spark.sql.ansi.enabled", True)
try:
df = spark.range(10)
df.withColumn("div_zero", (df.id / 0) * 10).collect()
except Exception as ee:
e = ee
e.getQueryContext()[0].fragment()
```
- classic
```
>>> e.getQueryContext()[0].fragment()
'__truediv__'
```
- connect
```
>>> e.getQueryContext()[0].fragment()
'__mul__'
```
### Does this PR introduce _any_ user-facing change?
The error context will be the same as Spark Classic.
### How was this patch tested?
Added the related test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48730 from ueshin/issues/SPARK-50196/dataframe_query_context.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/errors/utils.py | 6 ++--
python/pyspark/sql/connect/functions/builtin.py | 2 ++
python/pyspark/sql/functions/builtin.py | 2 ++
.../sql/tests/test_dataframe_query_context.py | 37 ++++++++++++++++++++++
.../sql/connect/planner/SparkConnectPlanner.scala | 18 +++++------
5 files changed, 53 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py
index 1b7dbbb195ed..cbe5739204ac 100644
--- a/python/pyspark/errors/utils.py
+++ b/python/pyspark/errors/utils.py
@@ -33,6 +33,7 @@ from typing import (
Union,
TYPE_CHECKING,
overload,
+ cast,
)
import pyspark
from pyspark.errors.error_classes import ERROR_CLASSES_MAP
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
from pyspark.sql import SparkSession
T = TypeVar("T")
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
_current_origin = threading.local()
@@ -225,7 +227,7 @@ def _capture_call_site(spark_session: "SparkSession",
depth: int) -> str:
return call_sites_str
-def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
+def _with_origin(func: FuncT) -> FuncT:
"""
A decorator to capture and provide the call site information to the server
side
when PySpark API functions are invoked.
@@ -269,7 +271,7 @@ def _with_origin(func: Callable[..., Any]) -> Callable[...,
Any]:
else:
return func(*args, **kwargs)
- return wrapper
+ return cast(FuncT, wrapper)
@overload
diff --git a/python/pyspark/sql/connect/functions/builtin.py
b/python/pyspark/sql/connect/functions/builtin.py
index b8bd0e9bf7fd..6f3ce942eb17 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -43,6 +43,7 @@ import sys
import numpy as np
from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors.utils import _with_origin
from pyspark.sql.dataframe import DataFrame as ParentDataFrame
from pyspark.sql import Column
from pyspark.sql.connect.expressions import (
@@ -238,6 +239,7 @@ def _options_to_col(options: Mapping[str, Any]) -> Column:
# Normal Functions
+@_with_origin
def col(col: str) -> Column:
from pyspark.sql.connect.column import Column as ConnectColumn
diff --git a/python/pyspark/sql/functions/builtin.py
b/python/pyspark/sql/functions/builtin.py
index d6662421a79e..1e5349fb1649 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -40,6 +40,7 @@ from typing import (
)
from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors.utils import _with_origin
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame as ParentDataFrame
from pyspark.sql.types import (
@@ -293,6 +294,7 @@ def lit(col: Any) -> Column:
@_try_remote_functions
+@_with_origin
def col(col: str) -> Column:
"""
Returns a :class:`~pyspark.sql.Column` based on the given column name.
diff --git a/python/pyspark/sql/tests/test_dataframe_query_context.py
b/python/pyspark/sql/tests/test_dataframe_query_context.py
index bf0cc021ca77..edd769680c77 100644
--- a/python/pyspark/sql/tests/test_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/test_dataframe_query_context.py
@@ -22,6 +22,7 @@ from pyspark.errors import (
QueryContextType,
NumberFormatException,
)
+from pyspark.sql import functions as sf
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
)
@@ -449,6 +450,42 @@ class DataFrameQueryContextTestsMixin:
query_context_type=None,
)
+ def test_query_context_complex(self):
+ with self.sql_conf({"spark.sql.ansi.enabled": True}):
+ # SQLQueryContext
+ with self.assertRaises(ArithmeticException) as pe:
+ self.spark.sql("select (10/0)*100").collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="DIVIDE_BY_ZERO",
+ messageParameters={"config": '"spark.sql.ansi.enabled"'},
+ query_context_type=QueryContextType.SQL,
+ )
+
+ # DataFrameQueryContext
+ df = self.spark.range(10)
+ with self.assertRaises(ArithmeticException) as pe:
+ df.withColumn("div_zero", (df.id / 0) * 10).collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="DIVIDE_BY_ZERO",
+ messageParameters={"config": '"spark.sql.ansi.enabled"'},
+ query_context_type=QueryContextType.DataFrame,
+ fragment="__truediv__",
+ )
+
+ def test_dataframe_query_context_col(self):
+ with self.assertRaises(AnalysisException) as pe:
+ self.spark.range(1).select(sf.col("id") + sf.col("idd")).show()
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ messageParameters={"objectName": "`idd`", "proposal": "`id`"},
+ query_context_type=QueryContextType.DataFrame,
+ fragment="col",
+ )
+
class DataFrameQueryContextTests(DataFrameQueryContextTestsMixin,
ReusedSQLTestCase):
pass
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a9d2bd482150..979fd83612e7 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -55,8 +55,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter,
Inner, JoinType, L
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment,
CoGroup, CollectMetrics, CommandResult, Deduplicate,
DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except,
FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith,
LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions,
MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias,
TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateSt [...]
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
-import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap,
CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter,
ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter,
StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1510,14 +1509,13 @@ class SparkConnectPlanner(
def transformExpression(
exp: proto.Expression,
baseRelationOpt: Option[LogicalPlan]): Expression = if (exp.hasCommon) {
- try {
- val origin = exp.getCommon.getOrigin
- PySparkCurrentOrigin.set(
- origin.getPythonOrigin.getFragment,
- origin.getPythonOrigin.getCallSite)
- withOrigin { doTransformExpression(exp, baseRelationOpt) }
- } finally {
- PySparkCurrentOrigin.clear()
+ CurrentOrigin.withOrigin {
+ val pythonOrigin = exp.getCommon.getOrigin.getPythonOrigin
+ val pysparkErrorContext = (pythonOrigin.getFragment,
pythonOrigin.getCallSite)
+ val newOrigin = CurrentOrigin.get.copy(pysparkErrorContext =
Some(pysparkErrorContext))
+ CurrentOrigin.withOrigin(newOrigin) {
+ doTransformExpression(exp, baseRelationOpt)
+ }
}
} else {
doTransformExpression(exp, baseRelationOpt)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]