itholic commented on code in PR #48364:
URL: https://github.com/apache/spark/pull/48364#discussion_r1798734222


##########
python/pyspark/sql/tests/test_connect_compatibility.py:
##########
@@ -23,155 +23,165 @@
 from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
 from pyspark.sql.classic.column import Column as ClassicColumn
 from pyspark.sql.session import SparkSession as ClassicSparkSession
+from pyspark.sql.catalog import Catalog as ClassicCatalog
 
 if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
     from pyspark.sql.connect.column import Column as ConnectColumn
     from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
+    from pyspark.sql.connect.catalog import Catalog as ConnectCatalog
 
 
 class ConnectCompatibilityTestsMixin:
-    def get_public_methods(self, cls):
+    def _get_public_methods(self, cls):
         """Get public methods of a class."""
         return {
             name: method
             for name, method in inspect.getmembers(cls, 
predicate=inspect.isfunction)
             if not name.startswith("_")
         }
 
-    def get_public_properties(self, cls):
+    def _get_public_properties(self, cls):
         """Get public properties of a class."""
         return {
             name: member
             for name, member in inspect.getmembers(cls)
             if isinstance(member, property) and not name.startswith("_")
         }
 
-    def test_signature_comparison_between_classic_and_connect(self):
-        def compare_method_signatures(classic_cls, connect_cls, cls_name):
-            """Compare method signatures between classic and connect 
classes."""
-            classic_methods = self.get_public_methods(classic_cls)
-            connect_methods = self.get_public_methods(connect_cls)
-
-            common_methods = set(classic_methods.keys()) & 
set(connect_methods.keys())
-
-            for method in common_methods:
-                classic_signature = inspect.signature(classic_methods[method])
-                connect_signature = inspect.signature(connect_methods[method])
-
-                # createDataFrame cannot be the same since RDD is not 
supported from Spark Connect
-                if not method == "createDataFrame":
-                    self.assertEqual(
-                        classic_signature,
-                        connect_signature,
-                        f"Signature mismatch in {cls_name} method '{method}'\n"
-                        f"Classic: {classic_signature}\n"
-                        f"Connect: {connect_signature}",
-                    )
-
-        # DataFrame API signature comparison
-        compare_method_signatures(ClassicDataFrame, ConnectDataFrame, 
"DataFrame")
-
-        # Column API signature comparison
-        compare_method_signatures(ClassicColumn, ConnectColumn, "Column")
-
-        # SparkSession API signature comparison
-        compare_method_signatures(ClassicSparkSession, ConnectSparkSession, 
"SparkSession")
-
-    def test_property_comparison_between_classic_and_connect(self):
-        def compare_property_lists(classic_cls, connect_cls, cls_name, 
expected_missing_properties):
-            """Compare properties between classic and connect classes."""
-            classic_properties = self.get_public_properties(classic_cls)
-            connect_properties = self.get_public_properties(connect_cls)
-
-            # Identify missing properties
-            classic_only_properties = set(classic_properties.keys()) - set(
-                connect_properties.keys()
-            )
-
-            # Compare the actual missing properties with the expected ones
-            self.assertEqual(
-                classic_only_properties,
-                expected_missing_properties,
-                f"{cls_name}: Unexpected missing properties in Connect: 
{classic_only_properties}",
-            )
-
-        # Expected missing properties for DataFrame
-        expected_missing_properties_for_dataframe = {"sql_ctx", "isStreaming"}
-
-        # DataFrame properties comparison
-        compare_property_lists(
-            ClassicDataFrame,
-            ConnectDataFrame,
-            "DataFrame",
-            expected_missing_properties_for_dataframe,
+    def _compare_method_signatures(self, classic_cls, connect_cls, cls_name):
+        """Compare method signatures between classic and connect classes."""
+        classic_methods = self._get_public_methods(classic_cls)
+        connect_methods = self._get_public_methods(connect_cls)
+
+        common_methods = set(classic_methods.keys()) & 
set(connect_methods.keys())
+
+        for method in common_methods:
+            classic_signature = inspect.signature(classic_methods[method])
+            connect_signature = inspect.signature(connect_methods[method])
+
+            if not method == "createDataFrame":
+                self.assertEqual(
+                    classic_signature,
+                    connect_signature,
+                    f"Signature mismatch in {cls_name} method '{method}'\n"
+                    f"Classic: {classic_signature}\n"
+                    f"Connect: {connect_signature}",
+                )
+
+    def _compare_property_lists(
+        self, classic_cls, connect_cls, cls_name, expected_missing_properties
+    ):
+        """Compare properties between classic and connect classes."""
+        classic_properties = self._get_public_properties(classic_cls)
+        connect_properties = self._get_public_properties(connect_cls)
+
+        # Identify missing properties
+        classic_only_properties = set(classic_properties.keys()) - 
set(connect_properties.keys())
+
+        # Compare the actual missing properties with the expected ones
+        self.assertEqual(
+            classic_only_properties,
+            expected_missing_properties,
+            f"{cls_name}: Unexpected missing properties in Connect: 
{classic_only_properties}",
         )
 
-        # Expected missing properties for Column (if any, replace with actual 
values)
-        expected_missing_properties_for_column = set()
-
-        # Column properties comparison
-        compare_property_lists(
-            ClassicColumn, ConnectColumn, "Column", 
expected_missing_properties_for_column
-        )
+    def _check_missing_methods(self, classic_cls, connect_cls, cls_name, 
expected_missing_methods):
+        """Check for expected missing methods between classic and connect 
classes."""
+        classic_methods = self._get_public_methods(classic_cls)
+        connect_methods = self._get_public_methods(connect_cls)
 
-        # Expected missing properties for SparkSession
-        expected_missing_properties_for_spark_session = {"sparkContext", 
"version"}
+        # Identify missing methods
+        classic_only_methods = set(classic_methods.keys()) - 
set(connect_methods.keys())
 
-        # SparkSession properties comparison
-        compare_property_lists(
-            ClassicSparkSession,
-            ConnectSparkSession,
-            "SparkSession",
-            expected_missing_properties_for_spark_session,
+        # Compare the actual missing methods with the expected ones
+        self.assertEqual(
+            classic_only_methods,
+            expected_missing_methods,
+            f"{cls_name}: Unexpected missing methods in Connect: 
{classic_only_methods}",
         )
 
-    def test_missing_methods(self):
-        def check_missing_methods(classic_cls, connect_cls, cls_name, 
expected_missing_methods):
-            """Check for expected missing methods between classic and connect 
classes."""
-            classic_methods = self.get_public_methods(classic_cls)
-            connect_methods = self.get_public_methods(connect_cls)
-
-            # Identify missing methods
-            classic_only_methods = set(classic_methods.keys()) - 
set(connect_methods.keys())
-
-            # Compare the actual missing methods with the expected ones
-            self.assertEqual(
-                classic_only_methods,
-                expected_missing_methods,
-                f"{cls_name}: Unexpected missing methods in Connect: 
{classic_only_methods}",
-            )
-
-        # Expected missing methods for DataFrame
-        expected_missing_methods_for_dataframe = {
-            "inputFiles",
-            "isLocal",
-            "semanticHash",
-            "isEmpty",
-        }
-
-        # DataFrame missing method check
-        check_missing_methods(
-            ClassicDataFrame, ConnectDataFrame, "DataFrame", 
expected_missing_methods_for_dataframe
+    def check_compatibility(
+        self,
+        classic_cls,
+        connect_cls,
+        cls_name,
+        expected_missing_properties,
+        expected_missing_methods,
+    ):
+        """
+        Main method for checking compatibility between classic and connect.
+
+        This method performs the following checks:
+        - API signature comparison between classic and connect classes.
+        - Property comparison, identifying any missing properties in the 
connect class.

Review Comment:
   Just updated the test to check reverse compatibility as well. Thanks!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to