This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b6ff4aa237c [SPARK-41799][CONNECT][PYTHON][TESTS] Combine plan-related tests into single file
b6ff4aa237c is described below

commit b6ff4aa237cd4dcce20d6244295d038a7d3cfab7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Jan 1 16:06:50 2023 +0800

    [SPARK-41799][CONNECT][PYTHON][TESTS] Combine plan-related tests into single file
    
    ### What changes were proposed in this pull request?
    Combine plan-related tests into a single file
    
    ### Why are the changes needed?
    1. `test_connect_column_expressions`, `test_connect_plan_only`, and `test_connect_select_ops` do almost the same thing: generate a plan and then validate it;
    
    2. The three tests are small and normally finish within about 1 second:
    
    ```
    Starting test(python3.9): pyspark.sql.tests.connect.test_connect_column_expressions (temp output: /__w/spark/spark/python/target/8d15e343-028e-44de-b998-6a7e7cc98047/python3.9__pyspark.sql.tests.connect.test_connect_column_expressions__fkvdhn26.log)
    Finished test(python3.9): pyspark.sql.tests.connect.test_connect_column_expressions (0s)
    Starting test(python3.9): pyspark.sql.tests.connect.test_connect_function (temp output: /__w/spark/spark/python/target/b35958dc-4dd3-4420-8f44-cb6f66d568dc/python3.9__pyspark.sql.tests.connect.test_connect_function__te14hoiz.log)
    Finished test(python3.9): pyspark.sql.tests.connect.test_connect_function (80s)
    Starting test(python3.9): pyspark.sql.tests.connect.test_connect_plan_only (temp output: /__w/spark/spark/python/target/3225a48d-5b4c-4cbe-803d-680c9408e3a8/python3.9__pyspark.sql.tests.connect.test_connect_plan_only__4bjohyey.log)
    Finished test(python3.9): pyspark.sql.tests.connect.test_connect_plan_only (0s)
    Starting test(python3.9): pyspark.sql.tests.connect.test_connect_select_ops (temp output: /__w/spark/spark/python/target/fe6a37ff-9aa8-41d5-8204-44a86423381f/python3.9__pyspark.sql.tests.connect.test_connect_select_ops__cicvg0w7.log)
    Finished test(python3.9): pyspark.sql.tests.connect.test_connect_select_ops (0s)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    CI
    
    Closes #39323 from zhengruifeng/connect_test_reorg.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   4 +-
 .../sql/tests/connect/test_connect_basic.py        |  24 ++-
 .../sql/tests/connect/test_connect_column.py       |   2 +-
 .../connect/test_connect_column_expressions.py     | 195 ------------------
 ...t_connect_plan_only.py => test_connect_plan.py} | 225 ++++++++++++++++++---
 .../sql/tests/connect/test_connect_select_ops.py   |  71 -------
 6 files changed, 225 insertions(+), 296 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index df3a1f180fc..dff17792148 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -509,9 +509,7 @@ pyspark_connect = Module(
         "pyspark.sql.connect.window",
         "pyspark.sql.connect.column",
         # unittests
-        "pyspark.sql.tests.connect.test_connect_column_expressions",
-        "pyspark.sql.tests.connect.test_connect_plan_only",
-        "pyspark.sql.tests.connect.test_connect_select_ops",
+        "pyspark.sql.tests.connect.test_connect_plan",
         "pyspark.sql.tests.connect.test_connect_basic",
         "pyspark.sql.tests.connect.test_connect_function",
         "pyspark.sql.tests.connect.test_connect_column",
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 6cdef25d5bc..0b615d2e32a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -117,7 +117,7 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, 
ReusedPySparkTestCase, SQLT
         cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
 
 
-class SparkConnectTests(SparkConnectSQLTestCase):
+class SparkConnectBasicTests(SparkConnectSQLTestCase):
     def test_df_get_item(self):
         # SPARK-41779: test __getitem__
 
@@ -1746,6 +1746,28 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         ):
             cdf.groupBy("name").pivot("department").sum("salary", 
"department").show()
 
+    def test_unsupported_functions(self):
+        # SPARK-41225: Disable unsupported functions.
+        df = self.connect.read.table(self.tbl_name)
+        for f in (
+            "rdd",
+            "unpersist",
+            "cache",
+            "persist",
+            "withWatermark",
+            "observe",
+            "foreach",
+            "foreachPartition",
+            "toLocalIterator",
+            "checkpoint",
+            "localCheckpoint",
+            "_repr_html_",
+            "semanticHash",
+            "sameSemantics",
+        ):
+            with self.assertRaises(NotImplementedError):
+                getattr(df, f)()
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class ChannelBuilderTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py 
b/python/pyspark/sql/tests/connect/test_connect_column.py
index 33ed1aded01..ffee64706d5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -61,7 +61,7 @@ if should_test_connect:
     from pyspark.sql.connect.functions import lit
 
 
-class SparkConnectTests(SparkConnectSQLTestCase):
+class SparkConnectColumnTests(SparkConnectSQLTestCase):
     def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
         from pyspark.sql.dataframe import DataFrame as SDF
         from pyspark.sql.connect.dataframe import DataFrame as CDF
diff --git 
a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py 
b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
deleted file mode 100644
index 55ecb859805..00000000000
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ /dev/null
@@ -1,195 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import uuid
-import unittest
-import decimal
-import datetime
-
-from pyspark.testing.connectutils import (
-    PlanOnlyTestFixture,
-    should_test_connect,
-    connect_requirement_message,
-)
-
-if should_test_connect:
-    from pyspark.sql.connect.proto import Expression as ProtoExpression
-    import pyspark.sql.connect.plan as p
-    from pyspark.sql.connect.column import Column
-    import pyspark.sql.connect.functions as fun
-
-
[email protected](not should_test_connect, connect_requirement_message)
-class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
-    def test_simple_column_expressions(self):
-        df = self.connect.with_plan(p.Read("table"))
-
-        c1 = df.col_name
-        self.assertIsInstance(c1, Column)
-        c2 = df["col_name"]
-        self.assertIsInstance(c2, Column)
-        c3 = fun.col("col_name")
-        self.assertIsInstance(c3, Column)
-
-        # All Protos should be identical
-        cp1 = c1.to_plan(None)
-        cp2 = c2.to_plan(None)
-        cp3 = c3.to_plan(None)
-
-        self.assertIsNotNone(cp1)
-        self.assertEqual(cp1, cp2)
-        self.assertEqual(cp2, cp3)
-
-    def test_null_literal(self):
-        null_lit = fun.lit(None)
-        null_lit_p = null_lit.to_plan(None)
-        self.assertEqual(null_lit_p.literal.HasField("null"), True)
-
-    def test_binary_literal(self):
-        val = b"binary\0\0asas"
-        bin_lit = fun.lit(val)
-        bin_lit_p = bin_lit.to_plan(None)
-        self.assertEqual(bin_lit_p.literal.binary, val)
-
-    def test_uuid_literal(self):
-        val = uuid.uuid4()
-        with self.assertRaises(ValueError):
-            fun.lit(val)
-
-    def test_column_literals(self):
-        df = self.connect.with_plan(p.Read("table"))
-        lit_df = df.select(fun.lit(10))
-        self.assertIsNotNone(lit_df._plan.to_proto(None))
-
-        self.assertIsNotNone(fun.lit(10).to_plan(None))
-        plan = fun.lit(10).to_plan(None)
-        self.assertIs(plan.literal.integer, 10)
-
-        plan = fun.lit(1 << 33).to_plan(None)
-        self.assertEqual(plan.literal.long, 1 << 33)
-
-    def test_numeric_literal_types(self):
-        int_lit = fun.lit(10)
-        float_lit = fun.lit(10.1)
-        decimal_lit = fun.lit(decimal.Decimal(99))
-
-        self.assertIsNotNone(int_lit.to_plan(None))
-        self.assertIsNotNone(float_lit.to_plan(None))
-        self.assertIsNotNone(decimal_lit.to_plan(None))
-
-    def test_float_nan_inf(self):
-        na_lit = fun.lit(float("nan"))
-        self.assertIsNotNone(na_lit.to_plan(None))
-
-        inf_lit = fun.lit(float("inf"))
-        self.assertIsNotNone(inf_lit.to_plan(None))
-
-        inf_lit = fun.lit(float("-inf"))
-        self.assertIsNotNone(inf_lit.to_plan(None))
-
-    def test_datetime_literal_types(self):
-        """Test the different timestamp, date, and timedelta types."""
-        datetime_lit = fun.lit(datetime.datetime.now())
-
-        p = datetime_lit.to_plan(None)
-        self.assertIsNotNone(datetime_lit.to_plan(None))
-        self.assertGreater(p.literal.timestamp, 10000000000000)
-
-        date_lit = fun.lit(datetime.date.today())
-        time_delta = fun.lit(datetime.timedelta(days=1, seconds=2, 
microseconds=3))
-
-        self.assertIsNotNone(date_lit.to_plan(None))
-        self.assertIsNotNone(time_delta.to_plan(None))
-        # (24 * 3600 + 2) * 1000000 + 3
-        self.assertEqual(86402000003, 
time_delta.to_plan(None).literal.day_time_interval)
-
-    def test_list_to_literal(self):
-        """Test conversion of lists to literals"""
-        empty_list = []
-        single_type = [1, 2, 3, 4]
-        multi_type = ["ooo", 1, "asas", 2.3]
-
-        empty_list_lit = fun.lit(empty_list)
-        single_type_lit = fun.lit(single_type)
-        multi_type_lit = fun.lit(multi_type)
-
-        p = empty_list_lit.to_plan(None)
-        self.assertIsNotNone(p)
-
-        p = single_type_lit.to_plan(None)
-        self.assertIsNotNone(p)
-
-        p = multi_type_lit.to_plan(None)
-        self.assertIsNotNone(p)
-
-        lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
-        self.assertIsNotNone(lit_list_plan)
-
-    def test_column_alias(self) -> None:
-        # SPARK-40809: Support for Column Aliases
-        col0 = fun.col("a").alias("martin")
-        self.assertEqual("Column<'Alias(ColumnReference(a), (martin))'>", 
str(col0))
-
-        col0 = fun.col("a").alias("martin", metadata={"pii": True})
-        plan = col0.to_plan(self.session.client)
-        self.assertIsNotNone(plan)
-        self.assertEqual(plan.alias.metadata, '{"pii": true}')
-
-    def test_column_expressions(self):
-        """Test a more complex combination of expressions and their 
translation into
-        the protobuf structure."""
-        df = self.connect.with_plan(p.Read("table"))
-
-        expr = fun.lit(10) < fun.lit(10)
-        expr_plan = expr.to_plan(None)
-        self.assertIsNotNone(expr_plan.unresolved_function)
-        self.assertEqual(expr_plan.unresolved_function.function_name, "<")
-
-        expr = df.id % fun.lit(10) == fun.lit(10)
-        expr_plan = expr.to_plan(None)
-        self.assertIsNotNone(expr_plan.unresolved_function)
-        self.assertEqual(expr_plan.unresolved_function.function_name, "==")
-
-        lit_fun = expr_plan.unresolved_function.arguments[1]
-        self.assertIsInstance(lit_fun, ProtoExpression)
-        self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
-        self.assertEqual(lit_fun.literal.integer, 10)
-
-        mod_fun = expr_plan.unresolved_function.arguments[0]
-        self.assertIsInstance(mod_fun, ProtoExpression)
-        self.assertIsInstance(mod_fun.unresolved_function, 
ProtoExpression.UnresolvedFunction)
-        self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
-        self.assertIsInstance(mod_fun.unresolved_function.arguments[0], 
ProtoExpression)
-        self.assertIsInstance(
-            mod_fun.unresolved_function.arguments[0].unresolved_attribute,
-            ProtoExpression.UnresolvedAttribute,
-        )
-        self.assertEqual(
-            
mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier,
 "id"
-        )
-
-
-if __name__ == "__main__":
-    import unittest
-    from pyspark.sql.tests.connect.test_connect_column_expressions import *  # 
noqa: F401
-
-    try:
-        import xmlrunner
-
-        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
-    except ImportError:
-        testRunner = None
-    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py 
b/python/pyspark/sql/tests/connect/test_connect_plan.py
similarity index 79%
rename from python/pyspark/sql/tests/connect/test_connect_plan_only.py
rename to python/pyspark/sql/tests/connect/test_connect_plan.py
index 5e3c6661e52..498c273e7f2 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 #
 import unittest
+import uuid
+import datetime
+import decimal
 
 from pyspark.testing.connectutils import (
     PlanOnlyTestFixture,
@@ -26,8 +29,9 @@ if should_test_connect:
     import pyspark.sql.connect.proto as proto
     from pyspark.sql.connect.column import Column
     from pyspark.sql.connect.dataframe import DataFrame
-    from pyspark.sql.connect.plan import WriteOperation
+    from pyspark.sql.connect.plan import WriteOperation, Read
     from pyspark.sql.connect.readwriter import DataFrameReader
+    from pyspark.sql.connect.functions import col, lit
     from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
     from pyspark.sql.connect.types import pyspark_types_to_proto_types
     from pyspark.sql.types import (
@@ -41,7 +45,7 @@ if should_test_connect:
 
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
-class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
+class SparkConnectPlanTests(PlanOnlyTestFixture):
     """These test cases exercise the interface to the proto plan
     generation but do not call Spark."""
 
@@ -327,6 +331,24 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
         self.assertEqual(plan.root.freq_items.support, 0.01)
 
+    def test_freqItems(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+        plan = (
+            df.filter(df.col_name > 3).freqItems(["col_a", "col_b"], 
1)._plan.to_proto(self.connect)
+        )
+        self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.freq_items.support, 1)
+        plan = df.filter(df.col_name > 3).freqItems(["col_a", 
"col_b"])._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.freq_items.support, 0.01)
+
+        plan = df.stat.freqItems(["col_a", "col_b"], 
1)._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.freq_items.support, 1)
+        plan = df.stat.freqItems(["col_a", 
"col_b"])._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.freq_items.support, 0.01)
+
     def test_limit(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         limit_plan = df.limit(10)._plan.to_proto(self.connect)
@@ -578,28 +600,6 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         new_plan = df.to(schema)._plan.to_proto(self.connect)
         self.assertEqual(pyspark_types_to_proto_types(schema), 
new_plan.root.to_schema.schema)
 
-    def test_unsupported_functions(self):
-        # SPARK-41225: Disable unsupported functions.
-        df = self.connect.readTable(table_name=self.tbl_name)
-        for f in (
-            "rdd",
-            "unpersist",
-            "cache",
-            "persist",
-            "withWatermark",
-            "observe",
-            "foreach",
-            "foreachPartition",
-            "toLocalIterator",
-            "checkpoint",
-            "localCheckpoint",
-            "_repr_html_",
-            "semanticHash",
-            "sameSemantics",
-        ):
-            with self.assertRaises(NotImplementedError):
-                getattr(df, f)()
-
     def test_write_operation(self):
         wo = WriteOperation(self.connect.readTable("name")._plan)
         wo.mode = "overwrite"
@@ -670,9 +670,184 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         for line in expected:
             self.assertIn(line, actual)
 
+    def test_select_with_columns_and_strings(self):
+        df = self.connect.with_plan(Read("table"))
+        
self.assertIsNotNone(df.select(col("name"))._plan.to_proto(self.connect))
+        self.assertIsNotNone(df.select("name"))
+        self.assertIsNotNone(df.select("name", "name2"))
+        self.assertIsNotNone(df.select(col("name"), col("name2")))
+        self.assertIsNotNone(df.select(col("name"), "name2"))
+        self.assertIsNotNone(df.select("*"))
+
+    def test_join_with_join_type(self):
+        df_left = self.connect.with_plan(Read("table"))
+        df_right = self.connect.with_plan(Read("table"))
+        for (join_type_str, join_type) in [
+            (None, proto.Join.JoinType.JOIN_TYPE_INNER),
+            ("inner", proto.Join.JoinType.JOIN_TYPE_INNER),
+            ("outer", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
+            ("leftouter", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
+            ("rightouter", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
+            ("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
+            ("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
+            ("cross", proto.Join.JoinType.JOIN_TYPE_CROSS),
+        ]:
+            joined_df = df_left.join(df_right, on=col("name"), 
how=join_type_str)._plan.to_proto(
+                self.connect
+            )
+            self.assertEqual(joined_df.root.join.join_type, join_type)
+
+    def test_simple_column_expressions(self):
+        df = self.connect.with_plan(Read("table"))
+
+        c1 = df.col_name
+        self.assertIsInstance(c1, Column)
+        c2 = df["col_name"]
+        self.assertIsInstance(c2, Column)
+        c3 = col("col_name")
+        self.assertIsInstance(c3, Column)
+
+        # All Protos should be identical
+        cp1 = c1.to_plan(None)
+        cp2 = c2.to_plan(None)
+        cp3 = c3.to_plan(None)
+
+        self.assertIsNotNone(cp1)
+        self.assertEqual(cp1, cp2)
+        self.assertEqual(cp2, cp3)
+
+    def test_null_literal(self):
+        null_lit = lit(None)
+        null_lit_p = null_lit.to_plan(None)
+        self.assertEqual(null_lit_p.literal.HasField("null"), True)
+
+    def test_binary_literal(self):
+        val = b"binary\0\0asas"
+        bin_lit = lit(val)
+        bin_lit_p = bin_lit.to_plan(None)
+        self.assertEqual(bin_lit_p.literal.binary, val)
+
+    def test_uuid_literal(self):
+
+        val = uuid.uuid4()
+        with self.assertRaises(ValueError):
+            lit(val)
+
+    def test_column_literals(self):
+        df = self.connect.with_plan(Read("table"))
+        lit_df = df.select(lit(10))
+        self.assertIsNotNone(lit_df._plan.to_proto(None))
+
+        self.assertIsNotNone(lit(10).to_plan(None))
+        plan = lit(10).to_plan(None)
+        self.assertIs(plan.literal.integer, 10)
+
+        plan = lit(1 << 33).to_plan(None)
+        self.assertEqual(plan.literal.long, 1 << 33)
+
+    def test_numeric_literal_types(self):
+        int_lit = lit(10)
+        float_lit = lit(10.1)
+        decimal_lit = lit(decimal.Decimal(99))
+
+        self.assertIsNotNone(int_lit.to_plan(None))
+        self.assertIsNotNone(float_lit.to_plan(None))
+        self.assertIsNotNone(decimal_lit.to_plan(None))
+
+    def test_float_nan_inf(self):
+        na_lit = lit(float("nan"))
+        self.assertIsNotNone(na_lit.to_plan(None))
+
+        inf_lit = lit(float("inf"))
+        self.assertIsNotNone(inf_lit.to_plan(None))
+
+        inf_lit = lit(float("-inf"))
+        self.assertIsNotNone(inf_lit.to_plan(None))
+
+    def test_datetime_literal_types(self):
+        """Test the different timestamp, date, and timedelta types."""
+        datetime_lit = lit(datetime.datetime.now())
+
+        p = datetime_lit.to_plan(None)
+        self.assertIsNotNone(datetime_lit.to_plan(None))
+        self.assertGreater(p.literal.timestamp, 10000000000000)
+
+        date_lit = lit(datetime.date.today())
+        time_delta = lit(datetime.timedelta(days=1, seconds=2, microseconds=3))
+
+        self.assertIsNotNone(date_lit.to_plan(None))
+        self.assertIsNotNone(time_delta.to_plan(None))
+        # (24 * 3600 + 2) * 1000000 + 3
+        self.assertEqual(86402000003, 
time_delta.to_plan(None).literal.day_time_interval)
+
+    def test_list_to_literal(self):
+        """Test conversion of lists to literals"""
+        empty_list = []
+        single_type = [1, 2, 3, 4]
+        multi_type = ["ooo", 1, "asas", 2.3]
+
+        empty_list_lit = lit(empty_list)
+        single_type_lit = lit(single_type)
+        multi_type_lit = lit(multi_type)
+
+        p = empty_list_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        p = single_type_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        p = multi_type_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        lit_list_plan = lit([lit(10), lit("str")]).to_plan(None)
+        self.assertIsNotNone(lit_list_plan)
+
+    def test_column_alias(self) -> None:
+        # SPARK-40809: Support for Column Aliases
+        col0 = col("a").alias("martin")
+        self.assertEqual("Column<'Alias(ColumnReference(a), (martin))'>", 
str(col0))
+
+        col0 = col("a").alias("martin", metadata={"pii": True})
+        plan = col0.to_plan(self.session.client)
+        self.assertIsNotNone(plan)
+        self.assertEqual(plan.alias.metadata, '{"pii": true}')
+
+    def test_column_expressions(self):
+        """Test a more complex combination of expressions and their 
translation into
+        the protobuf structure."""
+        df = self.connect.with_plan(Read("table"))
+
+        expr = lit(10) < lit(10)
+        expr_plan = expr.to_plan(None)
+        self.assertIsNotNone(expr_plan.unresolved_function)
+        self.assertEqual(expr_plan.unresolved_function.function_name, "<")
+
+        expr = df.id % lit(10) == lit(10)
+        expr_plan = expr.to_plan(None)
+        self.assertIsNotNone(expr_plan.unresolved_function)
+        self.assertEqual(expr_plan.unresolved_function.function_name, "==")
+
+        lit_fun = expr_plan.unresolved_function.arguments[1]
+        self.assertIsInstance(lit_fun, proto.Expression)
+        self.assertIsInstance(lit_fun.literal, proto.Expression.Literal)
+        self.assertEqual(lit_fun.literal.integer, 10)
+
+        mod_fun = expr_plan.unresolved_function.arguments[0]
+        self.assertIsInstance(mod_fun, proto.Expression)
+        self.assertIsInstance(mod_fun.unresolved_function, 
proto.Expression.UnresolvedFunction)
+        self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
+        self.assertIsInstance(mod_fun.unresolved_function.arguments[0], 
proto.Expression)
+        self.assertIsInstance(
+            mod_fun.unresolved_function.arguments[0].unresolved_attribute,
+            proto.Expression.UnresolvedAttribute,
+        )
+        self.assertEqual(
+            
mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier,
 "id"
+        )
+
 
 if __name__ == "__main__":
-    from pyspark.sql.tests.connect.test_connect_plan_only import *  # noqa: 
F401
+    from pyspark.sql.tests.connect.test_connect_plan import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore
diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py 
b/python/pyspark/sql/tests/connect/test_connect_select_ops.py
deleted file mode 100644
index 7f8153f7fca..00000000000
--- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import unittest
-
-from pyspark.testing.connectutils import (
-    PlanOnlyTestFixture,
-    should_test_connect,
-    connect_requirement_message,
-)
-
-if should_test_connect:
-    from pyspark.sql.connect.functions import col
-    from pyspark.sql.connect.plan import Read
-    import pyspark.sql.connect.proto as proto
-
-
[email protected](not should_test_connect, connect_requirement_message)
-class SparkConnectToProtoSuite(PlanOnlyTestFixture):
-    def test_select_with_columns_and_strings(self):
-        df = self.connect.with_plan(Read("table"))
-        
self.assertIsNotNone(df.select(col("name"))._plan.to_proto(self.connect))
-        self.assertIsNotNone(df.select("name"))
-        self.assertIsNotNone(df.select("name", "name2"))
-        self.assertIsNotNone(df.select(col("name"), col("name2")))
-        self.assertIsNotNone(df.select(col("name"), "name2"))
-        self.assertIsNotNone(df.select("*"))
-
-    def test_join_with_join_type(self):
-        df_left = self.connect.with_plan(Read("table"))
-        df_right = self.connect.with_plan(Read("table"))
-        for (join_type_str, join_type) in [
-            (None, proto.Join.JoinType.JOIN_TYPE_INNER),
-            ("inner", proto.Join.JoinType.JOIN_TYPE_INNER),
-            ("outer", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
-            ("leftouter", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
-            ("rightouter", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
-            ("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
-            ("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
-            ("cross", proto.Join.JoinType.JOIN_TYPE_CROSS),
-        ]:
-            joined_df = df_left.join(df_right, on=col("name"), 
how=join_type_str)._plan.to_proto(
-                self.connect
-            )
-            self.assertEqual(joined_df.root.join.join_type, join_type)
-
-
-if __name__ == "__main__":
-    import unittest
-    from pyspark.sql.tests.connect.test_connect_select_ops import *  # noqa: 
F401
-
-    try:
-        import xmlrunner  # type: ignore
-
-        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
-    except ImportError:
-        testRunner = None
-    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to