itholic commented on code in PR #48364:
URL: https://github.com/apache/spark/pull/48364#discussion_r1798721118
##########
python/pyspark/sql/tests/test_connect_compatibility.py:
##########
@@ -23,155 +23,165 @@
from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
from pyspark.sql.classic.column import Column as ClassicColumn
from pyspark.sql.session import SparkSession as ClassicSparkSession
+from pyspark.sql.catalog import Catalog as ClassicCatalog
if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
+ from pyspark.sql.connect.catalog import Catalog as ConnectCatalog
class ConnectCompatibilityTestsMixin:
- def get_public_methods(self, cls):
+ def _get_public_methods(self, cls):
"""Get public methods of a class."""
return {
name: method
for name, method in inspect.getmembers(cls,
predicate=inspect.isfunction)
if not name.startswith("_")
}
- def get_public_properties(self, cls):
+ def _get_public_properties(self, cls):
"""Get public properties of a class."""
return {
name: member
for name, member in inspect.getmembers(cls)
if isinstance(member, property) and not name.startswith("_")
}
- def test_signature_comparison_between_classic_and_connect(self):
- def compare_method_signatures(classic_cls, connect_cls, cls_name):
- """Compare method signatures between classic and connect
classes."""
- classic_methods = self.get_public_methods(classic_cls)
- connect_methods = self.get_public_methods(connect_cls)
-
- common_methods = set(classic_methods.keys()) &
set(connect_methods.keys())
-
- for method in common_methods:
- classic_signature = inspect.signature(classic_methods[method])
- connect_signature = inspect.signature(connect_methods[method])
-
- # createDataFrame cannot be the same since RDD is not
supported from Spark Connect
- if not method == "createDataFrame":
- self.assertEqual(
- classic_signature,
- connect_signature,
- f"Signature mismatch in {cls_name} method '{method}'\n"
- f"Classic: {classic_signature}\n"
- f"Connect: {connect_signature}",
- )
-
- # DataFrame API signature comparison
- compare_method_signatures(ClassicDataFrame, ConnectDataFrame,
"DataFrame")
-
- # Column API signature comparison
- compare_method_signatures(ClassicColumn, ConnectColumn, "Column")
-
- # SparkSession API signature comparison
- compare_method_signatures(ClassicSparkSession, ConnectSparkSession,
"SparkSession")
-
- def test_property_comparison_between_classic_and_connect(self):
- def compare_property_lists(classic_cls, connect_cls, cls_name,
expected_missing_properties):
- """Compare properties between classic and connect classes."""
- classic_properties = self.get_public_properties(classic_cls)
- connect_properties = self.get_public_properties(connect_cls)
-
- # Identify missing properties
- classic_only_properties = set(classic_properties.keys()) - set(
- connect_properties.keys()
- )
-
- # Compare the actual missing properties with the expected ones
- self.assertEqual(
- classic_only_properties,
- expected_missing_properties,
- f"{cls_name}: Unexpected missing properties in Connect:
{classic_only_properties}",
- )
-
- # Expected missing properties for DataFrame
- expected_missing_properties_for_dataframe = {"sql_ctx", "isStreaming"}
-
- # DataFrame properties comparison
- compare_property_lists(
- ClassicDataFrame,
- ConnectDataFrame,
- "DataFrame",
- expected_missing_properties_for_dataframe,
+ def _compare_method_signatures(self, classic_cls, connect_cls, cls_name):
+ """Compare method signatures between classic and connect classes."""
+ classic_methods = self._get_public_methods(classic_cls)
+ connect_methods = self._get_public_methods(connect_cls)
+
+ common_methods = set(classic_methods.keys()) &
set(connect_methods.keys())
+
+ for method in common_methods:
+ classic_signature = inspect.signature(classic_methods[method])
+ connect_signature = inspect.signature(connect_methods[method])
+
+ if not method == "createDataFrame":
+ self.assertEqual(
+ classic_signature,
+ connect_signature,
+ f"Signature mismatch in {cls_name} method '{method}'\n"
+ f"Classic: {classic_signature}\n"
+ f"Connect: {connect_signature}",
+ )
+
+ def _compare_property_lists(
+ self, classic_cls, connect_cls, cls_name, expected_missing_properties
+ ):
+ """Compare properties between classic and connect classes."""
+ classic_properties = self._get_public_properties(classic_cls)
+ connect_properties = self._get_public_properties(connect_cls)
+
+ # Identify missing properties
+ classic_only_properties = set(classic_properties.keys()) -
set(connect_properties.keys())
+
+ # Compare the actual missing properties with the expected ones
+ self.assertEqual(
+ classic_only_properties,
+ expected_missing_properties,
+ f"{cls_name}: Unexpected missing properties in Connect:
{classic_only_properties}",
)
- # Expected missing properties for Column (if any, replace with actual
values)
- expected_missing_properties_for_column = set()
-
- # Column properties comparison
- compare_property_lists(
- ClassicColumn, ConnectColumn, "Column",
expected_missing_properties_for_column
- )
+ def _check_missing_methods(self, classic_cls, connect_cls, cls_name,
expected_missing_methods):
+ """Check for expected missing methods between classic and connect
classes."""
+ classic_methods = self._get_public_methods(classic_cls)
+ connect_methods = self._get_public_methods(connect_cls)
- # Expected missing properties for SparkSession
- expected_missing_properties_for_spark_session = {"sparkContext",
"version"}
+ # Identify missing methods
+ classic_only_methods = set(classic_methods.keys()) -
set(connect_methods.keys())
- # SparkSession properties comparison
- compare_property_lists(
- ClassicSparkSession,
- ConnectSparkSession,
- "SparkSession",
- expected_missing_properties_for_spark_session,
+ # Compare the actual missing methods with the expected ones
+ self.assertEqual(
+ classic_only_methods,
+ expected_missing_methods,
+ f"{cls_name}: Unexpected missing methods in Connect:
{classic_only_methods}",
)
- def test_missing_methods(self):
- def check_missing_methods(classic_cls, connect_cls, cls_name,
expected_missing_methods):
- """Check for expected missing methods between classic and connect
classes."""
- classic_methods = self.get_public_methods(classic_cls)
- connect_methods = self.get_public_methods(connect_cls)
-
- # Identify missing methods
- classic_only_methods = set(classic_methods.keys()) -
set(connect_methods.keys())
-
- # Compare the actual missing methods with the expected ones
- self.assertEqual(
- classic_only_methods,
- expected_missing_methods,
- f"{cls_name}: Unexpected missing methods in Connect:
{classic_only_methods}",
- )
-
- # Expected missing methods for DataFrame
- expected_missing_methods_for_dataframe = {
- "inputFiles",
- "isLocal",
- "semanticHash",
- "isEmpty",
- }
-
- # DataFrame missing method check
- check_missing_methods(
- ClassicDataFrame, ConnectDataFrame, "DataFrame",
expected_missing_methods_for_dataframe
+ def check_compatibility(
+ self,
+ classic_cls,
+ connect_cls,
+ cls_name,
+ expected_missing_properties,
+ expected_missing_methods,
+ ):
+ """
+ Main method for checking compatibility between classic and connect.
+
+ This method performs the following checks:
+ - API signature comparison between classic and connect classes.
+ - Property comparison, identifying any missing properties in the
connect class.
+ - Method comparison, identifying any missing methods in the connect
class.
+
+ Parameters
+ ----------
+ classic_cls : type
+ The classic class to compare.
+ connect_cls : type
+ The connect class to compare.
+ cls_name : str
+ The name of the class.
+ expected_missing_properties : set
+ A set of properties expected to be missing in the connect class.
+ expected_missing_methods : set
+ A set of methods expected to be missing in the connect class.
+ """
+ self._compare_method_signatures(classic_cls, connect_cls, cls_name)
+ self._compare_property_lists(
+ classic_cls, connect_cls, cls_name, expected_missing_properties
)
+ self._check_missing_methods(classic_cls, connect_cls, cls_name,
expected_missing_methods)
- # Expected missing methods for Column (if any, replace with actual
values)
- expected_missing_methods_for_column = set()
-
- # Column missing method check
- check_missing_methods(
- ClassicColumn, ConnectColumn, "Column",
expected_missing_methods_for_column
+ def test_dataframe_compatibility(self):
+ """Test DataFrame compatibility between classic and connect."""
+ expected_missing_properties = {"sql_ctx", "isStreaming"}
+ expected_missing_methods = {"inputFiles", "isLocal", "semanticHash",
"isEmpty"}
Review Comment:
I just realized there is a bug in these test functions: all of the listed properties and methods should be supported in Connect (except for `sql_ctx`).
Let me fix the test. Thanks for catching this!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]