This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b6ff4aa237c [SPARK-41799][CONNECT][PYTHON][TESTS] Combine plan-related tests into single file
b6ff4aa237c is described below
commit b6ff4aa237cd4dcce20d6244295d038a7d3cfab7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Jan 1 16:06:50 2023 +0800
[SPARK-41799][CONNECT][PYTHON][TESTS] Combine plan-related tests into single file
### What changes were proposed in this pull request?
Combine the plan-related tests into a single file.
### Why are the changes needed?
1. `test_connect_column_expressions`, `test_connect_plan_only`, and `test_connect_select_ops` did almost the same thing: generate a plan and then validate it;
2. the three tests are pretty small, and each normally finishes within 1 second:
```
Starting test(python3.9): pyspark.sql.tests.connect.test_connect_column_expressions (temp output: /__w/spark/spark/python/target/8d15e343-028e-44de-b998-6a7e7cc98047/python3.9__pyspark.sql.tests.connect.test_connect_column_expressions__fkvdhn26.log)
Finished test(python3.9): pyspark.sql.tests.connect.test_connect_column_expressions (0s)
Starting test(python3.9): pyspark.sql.tests.connect.test_connect_function (temp output: /__w/spark/spark/python/target/b35958dc-4dd3-4420-8f44-cb6f66d568dc/python3.9__pyspark.sql.tests.connect.test_connect_function__te14hoiz.log)
Finished test(python3.9): pyspark.sql.tests.connect.test_connect_function (80s)
Starting test(python3.9): pyspark.sql.tests.connect.test_connect_plan_only (temp output: /__w/spark/spark/python/target/3225a48d-5b4c-4cbe-803d-680c9408e3a8/python3.9__pyspark.sql.tests.connect.test_connect_plan_only__4bjohyey.log)
Finished test(python3.9): pyspark.sql.tests.connect.test_connect_plan_only (0s)
Starting test(python3.9): pyspark.sql.tests.connect.test_connect_select_ops (temp output: /__w/spark/spark/python/target/fe6a37ff-9aa8-41d5-8204-44a86423381f/python3.9__pyspark.sql.tests.connect.test_connect_select_ops__cicvg0w7.log)
Finished test(python3.9): pyspark.sql.tests.connect.test_connect_select_ops (0s)
```
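All three suites follow the same plan-only pattern that the merged `test_connect_plan.py` keeps: build a DataFrame on top of an unresolved `Read` relation, render it to a proto plan, and assert on the proto fields, without ever contacting a Spark server. A minimal sketch of that pattern (illustrative only; the `PlanOnlyPatternExample` class and its `filter` assertion are not part of this PR, and it assumes the `pyspark.sql.connect` modules from this repo are importable):
```
import unittest

from pyspark.testing.connectutils import (
    PlanOnlyTestFixture,
    should_test_connect,
    connect_requirement_message,
)

if should_test_connect:
    from pyspark.sql.connect.plan import Read
    from pyspark.sql.connect.functions import col


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class PlanOnlyPatternExample(PlanOnlyTestFixture):
    def test_filter_plan(self):
        # Generate: build a DataFrame on an unresolved read; no server is involved.
        df = self.connect.with_plan(Read("table"))
        plan = df.filter(col("id") > 1)._plan.to_proto(self.connect)
        # Validate: assert directly on the generated proto plan.
        self.assertTrue(plan.root.HasField("filter"))
```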
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
CI
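For a local spot check (assuming a standard Spark development environment; this invocation is not part of the patch), the merged module can be run on its own with the PySpark test runner:
```
python/run-tests --testnames 'pyspark.sql.tests.connect.test_connect_plan'
```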
Closes #39323 from zhengruifeng/connect_test_reorg.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 4 +-
.../sql/tests/connect/test_connect_basic.py | 24 ++-
.../sql/tests/connect/test_connect_column.py | 2 +-
.../connect/test_connect_column_expressions.py | 195 ------------------
...t_connect_plan_only.py => test_connect_plan.py} | 225 ++++++++++++++++++---
.../sql/tests/connect/test_connect_select_ops.py | 71 -------
6 files changed, 225 insertions(+), 296 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index df3a1f180fc..dff17792148 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -509,9 +509,7 @@ pyspark_connect = Module(
"pyspark.sql.connect.window",
"pyspark.sql.connect.column",
# unittests
- "pyspark.sql.tests.connect.test_connect_column_expressions",
- "pyspark.sql.tests.connect.test_connect_plan_only",
- "pyspark.sql.tests.connect.test_connect_select_ops",
+ "pyspark.sql.tests.connect.test_connect_plan",
"pyspark.sql.tests.connect.test_connect_basic",
"pyspark.sql.tests.connect.test_connect_function",
"pyspark.sql.tests.connect.test_connect_column",
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 6cdef25d5bc..0b615d2e32a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -117,7 +117,7 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLT
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
-class SparkConnectTests(SparkConnectSQLTestCase):
+class SparkConnectBasicTests(SparkConnectSQLTestCase):
def test_df_get_item(self):
# SPARK-41779: test __getitem__
@@ -1746,6 +1746,28 @@ class SparkConnectTests(SparkConnectSQLTestCase):
):
cdf.groupBy("name").pivot("department").sum("salary", "department").show()
+ def test_unsupported_functions(self):
+ # SPARK-41225: Disable unsupported functions.
+ df = self.connect.read.table(self.tbl_name)
+ for f in (
+ "rdd",
+ "unpersist",
+ "cache",
+ "persist",
+ "withWatermark",
+ "observe",
+ "foreach",
+ "foreachPartition",
+ "toLocalIterator",
+ "checkpoint",
+ "localCheckpoint",
+ "_repr_html_",
+ "semanticHash",
+ "sameSemantics",
+ ):
+ with self.assertRaises(NotImplementedError):
+ getattr(df, f)()
+
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ChannelBuilderTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 33ed1aded01..ffee64706d5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -61,7 +61,7 @@ if should_test_connect:
from pyspark.sql.connect.functions import lit
-class SparkConnectTests(SparkConnectSQLTestCase):
+class SparkConnectColumnTests(SparkConnectSQLTestCase):
def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
from pyspark.sql.dataframe import DataFrame as SDF
from pyspark.sql.connect.dataframe import DataFrame as CDF
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
deleted file mode 100644
index 55ecb859805..00000000000
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ /dev/null
@@ -1,195 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import uuid
-import unittest
-import decimal
-import datetime
-
-from pyspark.testing.connectutils import (
- PlanOnlyTestFixture,
- should_test_connect,
- connect_requirement_message,
-)
-
-if should_test_connect:
- from pyspark.sql.connect.proto import Expression as ProtoExpression
- import pyspark.sql.connect.plan as p
- from pyspark.sql.connect.column import Column
- import pyspark.sql.connect.functions as fun
-
-
[email protected](not should_test_connect, connect_requirement_message)
-class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
- def test_simple_column_expressions(self):
- df = self.connect.with_plan(p.Read("table"))
-
- c1 = df.col_name
- self.assertIsInstance(c1, Column)
- c2 = df["col_name"]
- self.assertIsInstance(c2, Column)
- c3 = fun.col("col_name")
- self.assertIsInstance(c3, Column)
-
- # All Protos should be identical
- cp1 = c1.to_plan(None)
- cp2 = c2.to_plan(None)
- cp3 = c3.to_plan(None)
-
- self.assertIsNotNone(cp1)
- self.assertEqual(cp1, cp2)
- self.assertEqual(cp2, cp3)
-
- def test_null_literal(self):
- null_lit = fun.lit(None)
- null_lit_p = null_lit.to_plan(None)
- self.assertEqual(null_lit_p.literal.HasField("null"), True)
-
- def test_binary_literal(self):
- val = b"binary\0\0asas"
- bin_lit = fun.lit(val)
- bin_lit_p = bin_lit.to_plan(None)
- self.assertEqual(bin_lit_p.literal.binary, val)
-
- def test_uuid_literal(self):
- val = uuid.uuid4()
- with self.assertRaises(ValueError):
- fun.lit(val)
-
- def test_column_literals(self):
- df = self.connect.with_plan(p.Read("table"))
- lit_df = df.select(fun.lit(10))
- self.assertIsNotNone(lit_df._plan.to_proto(None))
-
- self.assertIsNotNone(fun.lit(10).to_plan(None))
- plan = fun.lit(10).to_plan(None)
- self.assertIs(plan.literal.integer, 10)
-
- plan = fun.lit(1 << 33).to_plan(None)
- self.assertEqual(plan.literal.long, 1 << 33)
-
- def test_numeric_literal_types(self):
- int_lit = fun.lit(10)
- float_lit = fun.lit(10.1)
- decimal_lit = fun.lit(decimal.Decimal(99))
-
- self.assertIsNotNone(int_lit.to_plan(None))
- self.assertIsNotNone(float_lit.to_plan(None))
- self.assertIsNotNone(decimal_lit.to_plan(None))
-
- def test_float_nan_inf(self):
- na_lit = fun.lit(float("nan"))
- self.assertIsNotNone(na_lit.to_plan(None))
-
- inf_lit = fun.lit(float("inf"))
- self.assertIsNotNone(inf_lit.to_plan(None))
-
- inf_lit = fun.lit(float("-inf"))
- self.assertIsNotNone(inf_lit.to_plan(None))
-
- def test_datetime_literal_types(self):
- """Test the different timestamp, date, and timedelta types."""
- datetime_lit = fun.lit(datetime.datetime.now())
-
- p = datetime_lit.to_plan(None)
- self.assertIsNotNone(datetime_lit.to_plan(None))
- self.assertGreater(p.literal.timestamp, 10000000000000)
-
- date_lit = fun.lit(datetime.date.today())
- time_delta = fun.lit(datetime.timedelta(days=1, seconds=2, microseconds=3))
-
- self.assertIsNotNone(date_lit.to_plan(None))
- self.assertIsNotNone(time_delta.to_plan(None))
- # (24 * 3600 + 2) * 1000000 + 3
- self.assertEqual(86402000003, time_delta.to_plan(None).literal.day_time_interval)
-
- def test_list_to_literal(self):
- """Test conversion of lists to literals"""
- empty_list = []
- single_type = [1, 2, 3, 4]
- multi_type = ["ooo", 1, "asas", 2.3]
-
- empty_list_lit = fun.lit(empty_list)
- single_type_lit = fun.lit(single_type)
- multi_type_lit = fun.lit(multi_type)
-
- p = empty_list_lit.to_plan(None)
- self.assertIsNotNone(p)
-
- p = single_type_lit.to_plan(None)
- self.assertIsNotNone(p)
-
- p = multi_type_lit.to_plan(None)
- self.assertIsNotNone(p)
-
- lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
- self.assertIsNotNone(lit_list_plan)
-
- def test_column_alias(self) -> None:
- # SPARK-40809: Support for Column Aliases
- col0 = fun.col("a").alias("martin")
- self.assertEqual("Column<'Alias(ColumnReference(a), (martin))'>", str(col0))
-
- col0 = fun.col("a").alias("martin", metadata={"pii": True})
- plan = col0.to_plan(self.session.client)
- self.assertIsNotNone(plan)
- self.assertEqual(plan.alias.metadata, '{"pii": true}')
-
- def test_column_expressions(self):
- """Test a more complex combination of expressions and their
translation into
- the protobuf structure."""
- df = self.connect.with_plan(p.Read("table"))
-
- expr = fun.lit(10) < fun.lit(10)
- expr_plan = expr.to_plan(None)
- self.assertIsNotNone(expr_plan.unresolved_function)
- self.assertEqual(expr_plan.unresolved_function.function_name, "<")
-
- expr = df.id % fun.lit(10) == fun.lit(10)
- expr_plan = expr.to_plan(None)
- self.assertIsNotNone(expr_plan.unresolved_function)
- self.assertEqual(expr_plan.unresolved_function.function_name, "==")
-
- lit_fun = expr_plan.unresolved_function.arguments[1]
- self.assertIsInstance(lit_fun, ProtoExpression)
- self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
- self.assertEqual(lit_fun.literal.integer, 10)
-
- mod_fun = expr_plan.unresolved_function.arguments[0]
- self.assertIsInstance(mod_fun, ProtoExpression)
- self.assertIsInstance(mod_fun.unresolved_function, ProtoExpression.UnresolvedFunction)
- self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
- self.assertIsInstance(mod_fun.unresolved_function.arguments[0], ProtoExpression)
- self.assertIsInstance(
- mod_fun.unresolved_function.arguments[0].unresolved_attribute,
- ProtoExpression.UnresolvedAttribute,
- )
- self.assertEqual(
- mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier, "id"
- )
-
-
-if __name__ == "__main__":
- import unittest
- from pyspark.sql.tests.connect.test_connect_column_expressions import * # noqa: F401
-
- try:
- import xmlrunner
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
similarity index 79%
rename from python/pyspark/sql/tests/connect/test_connect_plan_only.py
rename to python/pyspark/sql/tests/connect/test_connect_plan.py
index 5e3c6661e52..498c273e7f2 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -15,6 +15,9 @@
# limitations under the License.
#
import unittest
+import uuid
+import datetime
+import decimal
from pyspark.testing.connectutils import (
PlanOnlyTestFixture,
@@ -26,8 +29,9 @@ if should_test_connect:
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.dataframe import DataFrame
- from pyspark.sql.connect.plan import WriteOperation
+ from pyspark.sql.connect.plan import WriteOperation, Read
from pyspark.sql.connect.readwriter import DataFrameReader
+ from pyspark.sql.connect.functions import col, lit
from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import (
@@ -41,7 +45,7 @@ if should_test_connect:
@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
+class SparkConnectPlanTests(PlanOnlyTestFixture):
"""These test cases exercise the interface to the proto plan
generation but do not call Spark."""
@@ -327,6 +331,24 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
self.assertEqual(plan.root.freq_items.support, 0.01)
+ def test_freqItems(self):
+ df = self.connect.readTable(table_name=self.tbl_name)
+ plan = (
+ df.filter(df.col_name > 3).freqItems(["col_a", "col_b"], 1)._plan.to_proto(self.connect)
+ )
+ self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.freq_items.support, 1)
+ plan = df.filter(df.col_name > 3).freqItems(["col_a", "col_b"])._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.freq_items.support, 0.01)
+
+ plan = df.stat.freqItems(["col_a", "col_b"], 1)._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.freq_items.support, 1)
+ plan = df.stat.freqItems(["col_a", "col_b"])._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.freq_items.support, 0.01)
+
def test_limit(self):
df = self.connect.readTable(table_name=self.tbl_name)
limit_plan = df.limit(10)._plan.to_proto(self.connect)
@@ -578,28 +600,6 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
new_plan = df.to(schema)._plan.to_proto(self.connect)
self.assertEqual(pyspark_types_to_proto_types(schema), new_plan.root.to_schema.schema)
- def test_unsupported_functions(self):
- # SPARK-41225: Disable unsupported functions.
- df = self.connect.readTable(table_name=self.tbl_name)
- for f in (
- "rdd",
- "unpersist",
- "cache",
- "persist",
- "withWatermark",
- "observe",
- "foreach",
- "foreachPartition",
- "toLocalIterator",
- "checkpoint",
- "localCheckpoint",
- "_repr_html_",
- "semanticHash",
- "sameSemantics",
- ):
- with self.assertRaises(NotImplementedError):
- getattr(df, f)()
-
def test_write_operation(self):
wo = WriteOperation(self.connect.readTable("name")._plan)
wo.mode = "overwrite"
@@ -670,9 +670,184 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
for line in expected:
self.assertIn(line, actual)
+ def test_select_with_columns_and_strings(self):
+ df = self.connect.with_plan(Read("table"))
+
+ self.assertIsNotNone(df.select(col("name"))._plan.to_proto(self.connect))
+ self.assertIsNotNone(df.select("name"))
+ self.assertIsNotNone(df.select("name", "name2"))
+ self.assertIsNotNone(df.select(col("name"), col("name2")))
+ self.assertIsNotNone(df.select(col("name"), "name2"))
+ self.assertIsNotNone(df.select("*"))
+
+ def test_join_with_join_type(self):
+ df_left = self.connect.with_plan(Read("table"))
+ df_right = self.connect.with_plan(Read("table"))
+ for (join_type_str, join_type) in [
+ (None, proto.Join.JoinType.JOIN_TYPE_INNER),
+ ("inner", proto.Join.JoinType.JOIN_TYPE_INNER),
+ ("outer", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
+ ("leftouter", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
+ ("rightouter", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
+ ("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
+ ("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
+ ("cross", proto.Join.JoinType.JOIN_TYPE_CROSS),
+ ]:
+ joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.to_proto(
+ self.connect
+ )
+ self.assertEqual(joined_df.root.join.join_type, join_type)
+
+ def test_simple_column_expressions(self):
+ df = self.connect.with_plan(Read("table"))
+
+ c1 = df.col_name
+ self.assertIsInstance(c1, Column)
+ c2 = df["col_name"]
+ self.assertIsInstance(c2, Column)
+ c3 = col("col_name")
+ self.assertIsInstance(c3, Column)
+
+ # All Protos should be identical
+ cp1 = c1.to_plan(None)
+ cp2 = c2.to_plan(None)
+ cp3 = c3.to_plan(None)
+
+ self.assertIsNotNone(cp1)
+ self.assertEqual(cp1, cp2)
+ self.assertEqual(cp2, cp3)
+
+ def test_null_literal(self):
+ null_lit = lit(None)
+ null_lit_p = null_lit.to_plan(None)
+ self.assertEqual(null_lit_p.literal.HasField("null"), True)
+
+ def test_binary_literal(self):
+ val = b"binary\0\0asas"
+ bin_lit = lit(val)
+ bin_lit_p = bin_lit.to_plan(None)
+ self.assertEqual(bin_lit_p.literal.binary, val)
+
+ def test_uuid_literal(self):
+
+ val = uuid.uuid4()
+ with self.assertRaises(ValueError):
+ lit(val)
+
+ def test_column_literals(self):
+ df = self.connect.with_plan(Read("table"))
+ lit_df = df.select(lit(10))
+ self.assertIsNotNone(lit_df._plan.to_proto(None))
+
+ self.assertIsNotNone(lit(10).to_plan(None))
+ plan = lit(10).to_plan(None)
+ self.assertIs(plan.literal.integer, 10)
+
+ plan = lit(1 << 33).to_plan(None)
+ self.assertEqual(plan.literal.long, 1 << 33)
+
+ def test_numeric_literal_types(self):
+ int_lit = lit(10)
+ float_lit = lit(10.1)
+ decimal_lit = lit(decimal.Decimal(99))
+
+ self.assertIsNotNone(int_lit.to_plan(None))
+ self.assertIsNotNone(float_lit.to_plan(None))
+ self.assertIsNotNone(decimal_lit.to_plan(None))
+
+ def test_float_nan_inf(self):
+ na_lit = lit(float("nan"))
+ self.assertIsNotNone(na_lit.to_plan(None))
+
+ inf_lit = lit(float("inf"))
+ self.assertIsNotNone(inf_lit.to_plan(None))
+
+ inf_lit = lit(float("-inf"))
+ self.assertIsNotNone(inf_lit.to_plan(None))
+
+ def test_datetime_literal_types(self):
+ """Test the different timestamp, date, and timedelta types."""
+ datetime_lit = lit(datetime.datetime.now())
+
+ p = datetime_lit.to_plan(None)
+ self.assertIsNotNone(datetime_lit.to_plan(None))
+ self.assertGreater(p.literal.timestamp, 10000000000000)
+
+ date_lit = lit(datetime.date.today())
+ time_delta = lit(datetime.timedelta(days=1, seconds=2, microseconds=3))
+
+ self.assertIsNotNone(date_lit.to_plan(None))
+ self.assertIsNotNone(time_delta.to_plan(None))
+ # (24 * 3600 + 2) * 1000000 + 3
+ self.assertEqual(86402000003, time_delta.to_plan(None).literal.day_time_interval)
+
+ def test_list_to_literal(self):
+ """Test conversion of lists to literals"""
+ empty_list = []
+ single_type = [1, 2, 3, 4]
+ multi_type = ["ooo", 1, "asas", 2.3]
+
+ empty_list_lit = lit(empty_list)
+ single_type_lit = lit(single_type)
+ multi_type_lit = lit(multi_type)
+
+ p = empty_list_lit.to_plan(None)
+ self.assertIsNotNone(p)
+
+ p = single_type_lit.to_plan(None)
+ self.assertIsNotNone(p)
+
+ p = multi_type_lit.to_plan(None)
+ self.assertIsNotNone(p)
+
+ lit_list_plan = lit([lit(10), lit("str")]).to_plan(None)
+ self.assertIsNotNone(lit_list_plan)
+
+ def test_column_alias(self) -> None:
+ # SPARK-40809: Support for Column Aliases
+ col0 = col("a").alias("martin")
+ self.assertEqual("Column<'Alias(ColumnReference(a), (martin))'>", str(col0))
+
+ col0 = col("a").alias("martin", metadata={"pii": True})
+ plan = col0.to_plan(self.session.client)
+ self.assertIsNotNone(plan)
+ self.assertEqual(plan.alias.metadata, '{"pii": true}')
+
+ def test_column_expressions(self):
+ """Test a more complex combination of expressions and their
translation into
+ the protobuf structure."""
+ df = self.connect.with_plan(Read("table"))
+
+ expr = lit(10) < lit(10)
+ expr_plan = expr.to_plan(None)
+ self.assertIsNotNone(expr_plan.unresolved_function)
+ self.assertEqual(expr_plan.unresolved_function.function_name, "<")
+
+ expr = df.id % lit(10) == lit(10)
+ expr_plan = expr.to_plan(None)
+ self.assertIsNotNone(expr_plan.unresolved_function)
+ self.assertEqual(expr_plan.unresolved_function.function_name, "==")
+
+ lit_fun = expr_plan.unresolved_function.arguments[1]
+ self.assertIsInstance(lit_fun, proto.Expression)
+ self.assertIsInstance(lit_fun.literal, proto.Expression.Literal)
+ self.assertEqual(lit_fun.literal.integer, 10)
+
+ mod_fun = expr_plan.unresolved_function.arguments[0]
+ self.assertIsInstance(mod_fun, proto.Expression)
+ self.assertIsInstance(mod_fun.unresolved_function, proto.Expression.UnresolvedFunction)
+ self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
+ self.assertIsInstance(mod_fun.unresolved_function.arguments[0], proto.Expression)
+ self.assertIsInstance(
+ mod_fun.unresolved_function.arguments[0].unresolved_attribute,
+ proto.Expression.UnresolvedAttribute,
+ )
+ self.assertEqual(
+ mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier, "id"
+ )
+
if __name__ == "__main__":
- from pyspark.sql.tests.connect.test_connect_plan_only import * # noqa: F401
+ from pyspark.sql.tests.connect.test_connect_plan import * # noqa: F401
try:
import xmlrunner # type: ignore
diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py
deleted file mode 100644
index 7f8153f7fca..00000000000
--- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import unittest
-
-from pyspark.testing.connectutils import (
- PlanOnlyTestFixture,
- should_test_connect,
- connect_requirement_message,
-)
-
-if should_test_connect:
- from pyspark.sql.connect.functions import col
- from pyspark.sql.connect.plan import Read
- import pyspark.sql.connect.proto as proto
-
-
[email protected](not should_test_connect, connect_requirement_message)
-class SparkConnectToProtoSuite(PlanOnlyTestFixture):
- def test_select_with_columns_and_strings(self):
- df = self.connect.with_plan(Read("table"))
-
- self.assertIsNotNone(df.select(col("name"))._plan.to_proto(self.connect))
- self.assertIsNotNone(df.select("name"))
- self.assertIsNotNone(df.select("name", "name2"))
- self.assertIsNotNone(df.select(col("name"), col("name2")))
- self.assertIsNotNone(df.select(col("name"), "name2"))
- self.assertIsNotNone(df.select("*"))
-
- def test_join_with_join_type(self):
- df_left = self.connect.with_plan(Read("table"))
- df_right = self.connect.with_plan(Read("table"))
- for (join_type_str, join_type) in [
- (None, proto.Join.JoinType.JOIN_TYPE_INNER),
- ("inner", proto.Join.JoinType.JOIN_TYPE_INNER),
- ("outer", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
- ("leftouter", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
- ("rightouter", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
- ("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
- ("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
- ("cross", proto.Join.JoinType.JOIN_TYPE_CROSS),
- ]:
- joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.to_proto(
- self.connect
- )
- self.assertEqual(joined_df.root.join.join_type, join_type)
-
-
-if __name__ == "__main__":
- import unittest
- from pyspark.sql.tests.connect.test_connect_select_ops import * # noqa: F401
-
- try:
- import xmlrunner # type: ignore
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)