This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 86ae0d2bc198 [SPARK-47274][PYTHON][SQL] Provide more useful context for PySpark DataFrame API errors
86ae0d2bc198 is described below

commit 86ae0d2bc19832f5bf5d872491cdede800427691
Author: Haejoon Lee <haejoon....@databricks.com>
AuthorDate: Thu Apr 11 09:41:31 2024 +0800

[SPARK-47274][PYTHON][SQL] Provide more useful context for PySpark DataFrame API errors

### What changes were proposed in this pull request?

This PR enhances the error messages generated by PySpark's DataFrame API, adding detailed context about the location within the user's PySpark code where the error occurred.

It adds the PySpark user call site information to the `DataFrameQueryContext` introduced in https://github.com/apache/spark/pull/43334, aiming to give PySpark users the same level of detailed error context for better usability and debugging efficiency with the DataFrame APIs.

This PR also introduces `QueryContext.pysparkCallSite` and `QueryContext.pysparkFragment` to easily retrieve the PySpark information from a query context.

It also enhances `check_error` so that it can verify the query context when one exists.
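As an illustrative sketch (not part of this patch; it assumes an active `SparkSession` named `spark` with ANSI mode enabled, and follows the accessor names added in `captured.py` below), the new query-context accessors could be consumed from user code like this:

```python
from pyspark.errors import QueryContextType
from pyspark.errors.exceptions.captured import CapturedException

try:
    df = spark.range(10)
    df.select(df.id / 0).collect()
except CapturedException as e:
    for context in e.getQueryContext():
        if context.contextType() == QueryContextType.DataFrame:
            # Expected here: fragment "divide" and a call site like "your_script.py:6"
            print(context.pysparkFragment(), context.pysparkCallSite())
```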
### Why are the changes needed?

To improve debuggability. Errors originating from PySpark operations can be difficult to debug with only the limited context in the error messages. While improvements on the JVM side have been made to offer detailed error contexts, PySpark errors often lack this level of detail.

### Does this PR introduce _any_ user-facing change?

No API changes, but error messages will include a reference to the exact line of user code that triggered the error, in addition to the existing descriptive error message.

For example, consider the following PySpark code snippet that triggers a `DIVIDE_BY_ZERO` error:

```python
1 spark.conf.set("spark.sql.ansi.enabled", True)
2
3 df = spark.range(10)
4 df.select(df.id / 0).show()
```

**Before:**

```
pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. SQLSTATE: 22012
== DataFrame ==
"divide" was called from
java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
```

**After:**

```
pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. SQLSTATE: 22012
== DataFrame ==
"divide" was called from
/.../spark/python/test_pyspark_error.py:4
```

Now the error message points to the exact problematic code path, with the file name and line number of the user's own code.

## Points to the actual problem site instead of the site where the action was called

Even when an action is called after multiple mixed transform operations, the exact problematic site is reported to the user:

**In:**

```python
1  spark.conf.set("spark.sql.ansi.enabled", True)
2  df = spark.range(10)
3
4  df1 = df.withColumn("div_ten", df.id / 10)
5  df2 = df1.withColumn("plus_four", df.id + 4)
6
7  # This is the problematic divide operation that causes DIVIDE_BY_ZERO.
8  df3 = df2.withColumn("div_zero", df.id / 0)
9  df4 = df3.withColumn("minus_five", df.id / 5)
10
11 df4.collect()
```

**Out:**

```
pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. SQLSTATE: 22012
== DataFrame ==
"divide" was called from
/.../spark/python/test_pyspark_error.py:8
```
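A further hedged sketch of the knob behind the call-site capture: the new `_bin_op` path in `column.py` (see the diff below) records one `file:line` entry per captured frame, up to the depth configured by `spark.sql.stackTracesInDataFrameContext`, so raising that value should list additional call sites. This is an assumption drawn from the diff, not an example from the patch itself:

```python
# Assumption from the column.py change below: the number of recorded
# "file:line" entries equals spark.sql.stackTracesInDataFrameContext.
spark.conf.set("spark.sql.stackTracesInDataFrameContext", 3)
spark.conf.set("spark.sql.ansi.enabled", True)

df = spark.range(10)
# The "== DataFrame ==" context should now show up to three call-site lines.
df.select(df.id / 0).show()
```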
+ "and": "and", + "or": "or", + } def _( self: "Column", other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"], ) -> "Column": jc = other._jc if isinstance(other, Column) else other - njc = getattr(self._jc, name)(jc) + if name in binary_operator_map: + from pyspark.sql import SparkSession + + spark = SparkSession._getActiveSessionOrCreate() + stack = list(reversed(inspect.stack())) + depth = int( + spark.conf.get("spark.sql.stackTracesInDataFrameContext") # type: ignore[arg-type] + ) + selected_frames = stack[:depth] + call_sites = [f"{frame.filename}:{frame.lineno}" for frame in selected_frames] + call_site_str = "\n".join(call_sites) + + njc = getattr(self._jc, "fn")(binary_operator_map[name], jc, name, call_site_str) + else: + njc = getattr(self._jc, name)(jc) return Column(njc) _.__doc__ = doc + _.__name__ = name return _ diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index 343f485553a9..6210d4ec72fe 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -30,6 +30,10 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase): def test_toDF_with_schema_string(self): super().test_toDF_with_schema_string() + @unittest.skip("Spark Connect does not support DataFrameQueryContext currently.") + def test_dataframe_error_context(self): + super().test_dataframe_error_context() + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 1eccb40e709c..3f6a8eece5b0 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -37,6 +37,9 @@ from pyspark.errors import ( AnalysisException, IllegalArgumentException, PySparkTypeError, + ArithmeticException, + QueryContextType, + NumberFormatException, ) from pyspark.testing.sqlutils import ( ReusedSQLTestCase, @@ -832,6 +835,488 @@ class DataFrameTestsMixin: self.assertEqual(df.schema, schema) self.assertEqual(df.collect(), data) + def test_dataframe_error_context(self): + # SPARK-47274: Add more useful contexts for PySpark DataFrame API errors. 
### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #45377 from itholic/error_context_for_dataframe_api.

Authored-by: Haejoon Lee <haejoon....@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 python/pyspark/errors/exceptions/captured.py       |   8 +
 python/pyspark/sql/column.py                       |  37 +-
 .../sql/tests/connect/test_parity_dataframe.py     |   4 +
 python/pyspark/sql/tests/test_dataframe.py         | 485 +++++++++++++++++++++
 python/pyspark/testing/utils.py                    |  30 ++
 .../apache/spark/sql/catalyst/parser/parsers.scala |   2 +-
 .../spark/sql/catalyst/trees/QueryContexts.scala   |  18 +-
 .../apache/spark/sql/catalyst/trees/origin.scala   |   5 +-
 .../main/scala/org/apache/spark/sql/Column.scala   |  23 +
 .../main/scala/org/apache/spark/sql/package.scala  |  73 +++-
 10 files changed, 669 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index e5ec257fb32e..2a30eba3fb22 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -409,5 +409,13 @@ class QueryContext(BaseQueryContext):
     def callSite(self) -> str:
         return str(self._q.callSite())
 
+    def pysparkFragment(self) -> Optional[str]:  # type: ignore[return]
+        if self.contextType() == QueryContextType.DataFrame:
+            return str(self._q.pysparkFragment())
+
+    def pysparkCallSite(self) -> Optional[str]:  # type: ignore[return]
+        if self.contextType() == QueryContextType.DataFrame:
+            return str(self._q.pysparkCallSite())
+
     def summary(self) -> str:
         return str(self._q.summary())
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 31c1013742a0..fb266b03c2ff 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -18,6 +18,7 @@
 import sys
 import json
 import warnings
+import inspect
 from typing import (
     cast,
     overload,
@@ -174,16 +175,50 @@ def _bin_op(
     ["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"
 ]:
     """Create a method for given binary operator"""
+    binary_operator_map = {
+        "plus": "+",
+        "minus": "-",
+        "divide": "/",
+        "multiply": "*",
+        "mod": "%",
+        "equalTo": "=",
+        "lt": "<",
+        "leq": "<=",
+        "geq": ">=",
+        "gt": ">",
+        "eqNullSafe": "<=>",
+        "bitwiseOR": "|",
+        "bitwiseAND": "&",
+        "bitwiseXOR": "^",
+        # Just following JVM rule even if the names of source and target are the same.
+        "and": "and",
+        "or": "or",
+    }
 
     def _(
         self: "Column",
         other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
     ) -> "Column":
         jc = other._jc if isinstance(other, Column) else other
-        njc = getattr(self._jc, name)(jc)
+        if name in binary_operator_map:
+            from pyspark.sql import SparkSession
+
+            spark = SparkSession._getActiveSessionOrCreate()
+            stack = list(reversed(inspect.stack()))
+            depth = int(
+                spark.conf.get("spark.sql.stackTracesInDataFrameContext")  # type: ignore[arg-type]
+            )
+            selected_frames = stack[:depth]
+            call_sites = [f"{frame.filename}:{frame.lineno}" for frame in selected_frames]
+            call_site_str = "\n".join(call_sites)
+
+            njc = getattr(self._jc, "fn")(binary_operator_map[name], jc, name, call_site_str)
+        else:
+            njc = getattr(self._jc, name)(jc)
         return Column(njc)
 
     _.__doc__ = doc
+    _.__name__ = name
     return _
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 343f485553a9..6210d4ec72fe 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -30,6 +30,10 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_toDF_with_schema_string(self):
         super().test_toDF_with_schema_string()
 
+    @unittest.skip("Spark Connect does not support DataFrameQueryContext currently.")
+    def test_dataframe_error_context(self):
+        super().test_dataframe_error_context()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 1eccb40e709c..3f6a8eece5b0 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -37,6 +37,9 @@ from pyspark.errors import (
     AnalysisException,
     IllegalArgumentException,
     PySparkTypeError,
+    ArithmeticException,
+    QueryContextType,
+    NumberFormatException,
 )
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -832,6 +835,488 @@ class DataFrameTestsMixin:
         self.assertEqual(df.schema, schema)
         self.assertEqual(df.collect(), data)
 
+    def test_dataframe_error_context(self):
+        # SPARK-47274: Add more useful contexts for PySpark DataFrame API errors.
+        with self.sql_conf({"spark.sql.ansi.enabled": True}):
+            df = self.spark.range(10)
+
+            # DataFrameQueryContext with pysparkLoggingInfo - divide
+            with self.assertRaises(ArithmeticException) as pe:
+                df.withColumn("div_zero", df.id / 0).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - plus
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("plus_invalid_type", df.id + "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - minus
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("minus_invalid_type", df.id - "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - multiply
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_invalid_type", df.id * "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - mod
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("mod_invalid_type", df.id % "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="mod",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - equalTo
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("equalTo_invalid_type", df.id == "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="equalTo",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - lt
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("lt_invalid_type", df.id < "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="lt",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - leq
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("leq_invalid_type", df.id <= "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="leq",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - geq
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("geq_invalid_type", df.id >= "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="geq",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - gt
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("gt_invalid_type", df.id > "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="gt",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("eqNullSafe_invalid_type", df.id.eqNullSafe("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="eqNullSafe",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - and
+            with self.assertRaises(AnalysisException) as pe:
+                df.withColumn("and_invalid_type", df.id & "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE",
+                message_parameters={
+                    "inputType": '"BOOLEAN"',
+                    "actualDataType": '"BIGINT"',
+                    "sqlExpr": '"(id AND string)"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="and",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - or
+            with self.assertRaises(AnalysisException) as pe:
+                df.withColumn("or_invalid_type", df.id | "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE",
+                message_parameters={
+                    "inputType": '"BOOLEAN"',
+                    "actualDataType": '"BIGINT"',
+                    "sqlExpr": '"(id OR string)"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="or",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("bitwiseOR_invalid_type", df.id.bitwiseOR("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="bitwiseOR",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("bitwiseAND_invalid_type", df.id.bitwiseAND("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="bitwiseAND",
+            )
df.withColumn("bitwiseAND_invalid_type", df.id.bitwiseAND("string")).collect() + self.check_error( + exception=pe.exception, + error_class="CAST_INVALID_INPUT", + message_parameters={ + "expression": "'string'", + "sourceType": '"STRING"', + "targetType": '"BIGINT"', + "ansiConfig": '"spark.sql.ansi.enabled"', + }, + query_context_type=QueryContextType.DataFrame, + pyspark_fragment="bitwiseAND", + ) + + # DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR + with self.assertRaises(NumberFormatException) as pe: + df.withColumn("bitwiseXOR_invalid_type", df.id.bitwiseXOR("string")).collect() + self.check_error( + exception=pe.exception, + error_class="CAST_INVALID_INPUT", + message_parameters={ + "expression": "'string'", + "sourceType": '"STRING"', + "targetType": '"BIGINT"', + "ansiConfig": '"spark.sql.ansi.enabled"', + }, + query_context_type=QueryContextType.DataFrame, + pyspark_fragment="bitwiseXOR", + ) + + # DataFrameQueryContext with pysparkLoggingInfo - chained (`divide` is problematic) + with self.assertRaises(ArithmeticException) as pe: + df.withColumn("multiply_ten", df.id * 10).withColumn( + "divide_zero", df.id / 0 + ).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect() + self.check_error( + exception=pe.exception, + error_class="DIVIDE_BY_ZERO", + message_parameters={"config": '"spark.sql.ansi.enabled"'}, + query_context_type=QueryContextType.DataFrame, + pyspark_fragment="divide", + ) + + # DataFrameQueryContext with pysparkLoggingInfo - chained (`plus` is problematic) + with self.assertRaises(NumberFormatException) as pe: + df.withColumn("multiply_ten", df.id * 10).withColumn( + "divide_ten", df.id / 10 + ).withColumn("plus_string", df.id + "string").withColumn( + "minus_ten", df.id - 10 + ).collect() + self.check_error( + exception=pe.exception, + error_class="CAST_INVALID_INPUT", + message_parameters={ + "expression": "'string'", + "sourceType": '"STRING"', + "targetType": '"BIGINT"', + "ansiConfig": '"spark.sql.ansi.enabled"', + }, + query_context_type=QueryContextType.DataFrame, + pyspark_fragment="plus", + ) + + # DataFrameQueryContext with pysparkLoggingInfo - chained (`minus` is problematic) + with self.assertRaises(NumberFormatException) as pe: + df.withColumn("multiply_ten", df.id * 10).withColumn( + "divide_ten", df.id / 10 + ).withColumn("plus_ten", df.id + 10).withColumn( + "minus_string", df.id - "string" + ).collect() + self.check_error( + exception=pe.exception, + error_class="CAST_INVALID_INPUT", + message_parameters={ + "expression": "'string'", + "sourceType": '"STRING"', + "targetType": '"BIGINT"', + "ansiConfig": '"spark.sql.ansi.enabled"', + }, + query_context_type=QueryContextType.DataFrame, + pyspark_fragment="minus", + ) + + # DataFrameQueryContext with pysparkLoggingInfo - chained (`multiply` is problematic) + with self.assertRaises(NumberFormatException) as pe: + df.withColumn("multiply_string", df.id * "string").withColumn( + "divide_ten", df.id / 10 + ).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect() + self.check_error( + exception=pe.exception, + error_class="CAST_INVALID_INPUT", + message_parameters={ + "expression": "'string'", + "sourceType": '"STRING"', + "targetType": '"BIGINT"', + "ansiConfig": '"spark.sql.ansi.enabled"', + }, + query_context_type=QueryContextType.DataFrame, + pyspark_fragment="multiply", + ) + + # Multiple expressions in df.select (`divide` is problematic) + with self.assertRaises(ArithmeticException) as pe: + df.select(df.id - 10, df.id + 4, df.id / 0, 
+
+            # Multiple expressions in df.select (`plus` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - 10, df.id + "string", df.id / 10, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # Multiple expressions in df.select (`minus` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - "string", df.id + 4, df.id / 10, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # Multiple expressions in df.select (`multiply` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - 10, df.id + 4, df.id / 10, df.id * "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # Multiple expressions with pre-declared expressions (`divide` is problematic)
+            a = df.id / 10
+            b = df.id / 0
+            with self.assertRaises(ArithmeticException) as pe:
+                df.select(a, df.id + 4, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # Multiple expressions with pre-declared expressions (`plus` is problematic)
+            a = df.id + "string"
+            b = df.id + 4
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id / 10, a, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # Multiple expressions with pre-declared expressions (`minus` is problematic)
+            a = df.id - "string"
+            b = df.id - 5
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(a, df.id / 10, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # Multiple expressions with pre-declared expressions (`multiply` is problematic)
+            a = df.id * "string"
+            b = df.id * 10
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(a, df.id / 10, b, df.id + 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # DataFrameQueryContext without pysparkLoggingInfo
+            with self.assertRaises(AnalysisException) as pe:
+                df.select("non-existing-column")
+            self.check_error(
+                exception=pe.exception,
+                error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+                message_parameters={"objectName": "`non-existing-column`", "proposal": "`id`"},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="",
+            )
+
+            # SQLQueryContext
+            with self.assertRaises(ArithmeticException) as pe:
+                self.spark.sql("select 10/0").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.SQL,
+            )
+
+            # No QueryContext
+            with self.assertRaises(AnalysisException) as pe:
+                self.spark.sql("select * from non-existing-table")
+            self.check_error(
+                exception=pe.exception,
+                error_class="INVALID_IDENTIFIER",
+                message_parameters={"ident": "non-existing-table"},
+                query_context_type=None,
+            )
+
 
 class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index de40685dedc0..fe25136864ee 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -54,6 +54,8 @@ except ImportError:
 
 from pyspark import SparkConf
 from pyspark.errors import PySparkAssertionError, PySparkException
+from pyspark.errors.exceptions.captured import CapturedException
+from pyspark.errors.exceptions.base import QueryContextType
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql import Row
@@ -284,7 +286,14 @@ class PySparkErrorTestUtils:
         exception: PySparkException,
         error_class: str,
         message_parameters: Optional[Dict[str, str]] = None,
+        query_context_type: Optional[QueryContextType] = None,
+        pyspark_fragment: Optional[str] = None,
     ):
+        query_context = exception.getQueryContext()
+        assert bool(query_context) == (query_context_type is not None), (
+            "`query_context_type` is required when QueryContext exists. "
+            f"QueryContext: {query_context}."
+        )
         # Test if given error is an instance of PySparkException.
         self.assertIsInstance(
             exception,
@@ -306,6 +315,27 @@ class PySparkErrorTestUtils:
             expected, actual, f"Expected message parameters was '{expected}', got '{actual}'"
         )
 
+        # Test query context
+        if query_context:
+            expected = query_context_type
+            actual_contexts = exception.getQueryContext()
+            for actual_context in actual_contexts:
+                actual = actual_context.contextType()
+                self.assertEqual(
+                    expected, actual, f"Expected QueryContext was '{expected}', got '{actual}'"
+                )
+                if actual == QueryContextType.DataFrame:
+                    assert (
+                        pyspark_fragment is not None
+                    ), "`pyspark_fragment` is required when QueryContextType is DataFrame."
+                    expected = pyspark_fragment
+                    actual = actual_context.pysparkFragment()
+                    self.assertEqual(
+                        expected,
+                        actual,
+                        f"Expected PySpark fragment was '{expected}', got '{actual}'",
+                    )
+
 
 def assertSchemaEqual(
     actual: StructType,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
index 6cfa7ed195a7..0a84ecd8203f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
@@ -250,7 +250,7 @@ class ParseException private(
     val builder = new StringBuilder
     builder ++= "\n" ++= message
     start match {
-      case Origin(Some(l), Some(p), _, _, _, _, _, _) =>
+      case Origin(Some(l), Some(p), _, _, _, _, _, _, _) =>
         builder ++= s" (line $l, pos $p)\n"
         command.foreach { cmd =>
           val (above, below) = cmd.split("\n").splitAt(l)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
index c716002ef35c..1c2456f00bcd 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
@@ -134,7 +134,9 @@ case class SQLQueryContext(
   override def callSite: String = throw SparkUnsupportedOperationException()
 }
 
-case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends QueryContext {
+case class DataFrameQueryContext(
+    stackTrace: Seq[StackTraceElement],
+    pysparkErrorContext: Option[(String, String)]) extends QueryContext {
   override val contextType = QueryContextType.DataFrame
 
   override def objectType: String = throw SparkUnsupportedOperationException()
@@ -155,16 +157,26 @@ case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends Que
 
   override val callSite: String = stackTrace.tail.mkString("\n")
 
+  val pysparkFragment: String = pysparkErrorContext.map(_._1).getOrElse("")
+  val pysparkCallSite: String = pysparkErrorContext.map(_._2).getOrElse("")
+
+  val (displayedFragment, displayedCallsite) = if (pysparkErrorContext.nonEmpty) {
+    (pysparkFragment, pysparkCallSite)
+  } else {
+    (fragment, callSite)
+  }
+
   override lazy val summary: String = {
     val builder = new StringBuilder
     builder ++= "== DataFrame ==\n"
     builder ++= "\""
-    builder ++= fragment
+    builder ++= displayedFragment
     builder ++= "\""
     builder ++= " was called from\n"
-    builder ++= callSite
+    builder ++= displayedCallsite
     builder += '\n'
+
     builder.result()
   }
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
index d8469d3056d5..9d3968b02535 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
@@ -32,10 +32,11 @@ case class Origin(
     sqlText: Option[String] = None,
     objectType: Option[String] = None,
     objectName: Option[String] = None,
-    stackTrace: Option[Array[StackTraceElement]] = None) {
+    stackTrace: Option[Array[StackTraceElement]] = None,
+    pysparkErrorContext: Option[(String, String)] = None) {
 
   lazy val context: QueryContext = if (stackTrace.isDefined) {
-    DataFrameQueryContext(stackTrace.get.toImmutableArraySeq)
+    DataFrameQueryContext(stackTrace.get.toImmutableArraySeq, pysparkErrorContext)
   } else {
     SQLQueryContext(
       line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index fdd315a44f1e..22c09c51c237 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -171,6 +171,29 @@ class Column(val expr: Expression) extends Logging {
     Column.fn(name, this, lit(other))
   }
 
+  /**
+   * A version of the `fn` method specifically designed for binary operations in PySpark
+   * that require logging information.
+   * This method is used when the operation involves another Column.
+   *
+   * @param name The name of the operation to be performed.
+   * @param other The value to be used in the operation, which will be converted to a
+   *              Column if not already one.
+   * @param pysparkFragment A string representing the 'fragment' of the PySpark error context,
+   *                        typically indicating the name of the PySpark function.
+   * @param pysparkCallSite A string representing the 'callSite' of the PySpark error context,
+   *                        providing the exact location within the PySpark code where the
+   *                        operation originated.
+   * @return A Column resulting from the operation.
+   */
+  private def fn(
+      name: String, other: Any, pysparkFragment: String, pysparkCallSite: String): Column = {
+    val tupleInfo = (pysparkFragment, pysparkCallSite)
+    withOrigin(Some(tupleInfo)) {
+      Column.fn(name, this, lit(other))
+    }
+  }
+
   override def toString: String = toPrettySQL(expr)
 
   override def equals(that: Any): Boolean = that match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 9831ce62801a..1444eea09b27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -78,6 +78,31 @@ package object sql {
    */
  private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = "org.apache.spark.legacyINT96"
 
+  /**
+   * Captures the current Java stack trace up to a specified depth defined by the
+   * `spark.sql.stackTracesInDataFrameContext` configuration. This method helps in identifying
+   * the call sites in Spark code by filtering out the stack frames until it reaches the
+   * user code calling into Spark. This method is intended to be used for enhancing debuggability
+   * by providing detailed context about where in the Spark source code a particular operation
+   * was called from.
+   *
+   * This functionality is crucial for both debugging purposes and for providing more insightful
+   * logging and error messages. By capturing the stack trace up to a certain depth, it enables
+   * a more precise pinpointing of the execution flow, especially useful when troubleshooting
+   * complex interactions within Spark.
+   *
+   * @return An array of `StackTraceElement` representing the filtered stack trace.
+   */
+  private def captureStackTrace(): Array[StackTraceElement] = {
+    val st = Thread.currentThread().getStackTrace
+    var i = 0
+    // Find the beginning of Spark code traces
+    while (i < st.length && !sparkCode(st(i))) i += 1
+    // Stop at the end of the first Spark code traces
+    while (i < st.length && sparkCode(st(i))) i += 1
+    st.slice(from = i - 1, until = i + SQLConf.get.stackTracesInDataFrameContext)
+  }
+
   /**
    * This helper function captures the Spark API and its call site in the user code from the
    * current stacktrace.
@@ -98,15 +123,45 @@ package object sql {
     if (CurrentOrigin.get.stackTrace.isDefined) {
       f
     } else {
-      val st = Thread.currentThread().getStackTrace
-      var i = 0
-      // Find the beginning of Spark code traces
-      while (i < st.length && !sparkCode(st(i))) i += 1
-      // Stop at the end of the first Spark code traces
-      while (i < st.length && sparkCode(st(i))) i += 1
-      val origin = Origin(stackTrace = Some(st.slice(
-        from = i - 1,
-        until = i + SQLConf.get.stackTracesInDataFrameContext)))
+      val origin = Origin(stackTrace = Some(captureStackTrace()))
       CurrentOrigin.withOrigin(origin)(f)
     }
   }
+
+  /**
+   * This overloaded helper function captures the call site information specifically for PySpark,
+   * using provided PySpark logging information instead of capturing the current Java stack trace.
+   *
+   * This method is designed to enhance the debuggability of PySpark by including PySpark-specific
+   * logging information (e.g., method names and call sites within PySpark scripts) in debug logs,
+   * without the overhead of capturing and processing Java stack traces that are less relevant
+   * to PySpark developers.
+   *
+   * The `pysparkErrorContext` parameter allows for passing PySpark call site information, which
+   * is then included in the Origin context. This facilitates more precise and useful logging for
+   * troubleshooting PySpark applications.
+   *
+   * This method should be used in places where PySpark API calls are made, and PySpark logging
+   * information is available and beneficial for debugging purposes.
+   *
+   * @param pysparkErrorContext Optional PySpark logging information including the call site,
+   *                            represented as a (String, String).
+   *                            This may contain keys like "fragment" and "callSite" to provide
+   *                            detailed context about the PySpark call site.
+   * @param f                   The function that can utilize the modified Origin context with
+   *                            PySpark logging information.
+   * @return The result of executing `f` within the context of the provided PySpark logging
+   *         information.
+   */
+  private[sql] def withOrigin[T](
+      pysparkErrorContext: Option[(String, String)] = None)(f: => T): T = {
+    if (CurrentOrigin.get.stackTrace.isDefined) {
+      f
+    } else {
+      val origin = Origin(
+        stackTrace = Some(captureStackTrace()),
+        pysparkErrorContext = pysparkErrorContext
+      )
+      CurrentOrigin.withOrigin(origin)(f)
+    }
+  }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org