This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 714feb7b5aec [SPARK-47377][PYTHON][CONNECT][TESTS][FOLLOWUP] Factor out more tests from `SparkConnectSQLTestCase`
714feb7b5aec is described below

commit 714feb7b5aec948c3499ad486dc63ea89a241ffa
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Mar 16 21:17:03 2024 +0900

    [SPARK-47377][PYTHON][CONNECT][TESTS][FOLLOWUP] Factor out more tests from `SparkConnectSQLTestCase`

    ### What changes were proposed in this pull request?
    This is a followup of https://github.com/apache/spark/pull/45497 , and factors out more tests from `SparkConnectSQLTestCase`. It should also be the last PR for this test :)

    ### Why are the changes needed?
    For testing parallelism.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #45533 from zhengruifeng/continue_break_basic.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 dev/sparktestsupport/modules.py                    |    2 +
 .../sql/tests/connect/test_connect_basic.py        | 1041 +-------------------
 .../sql/tests/connect/test_connect_creation.py     |   19 +-
 .../sql/tests/connect/test_connect_error.py        |  230 +++++
 .../sql/tests/connect/test_connect_readwriter.py   |   63 ++
 .../pyspark/sql/tests/connect/test_connect_stat.py |  813 +++++++++++++++
 6 files changed, 1128 insertions(+), 1040 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ef4d28bce4c4..2dd67989035d 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1001,12 +1001,14 @@ pyspark_connect = Module(
         # sql unittests
         "pyspark.sql.tests.connect.test_connect_plan",
         "pyspark.sql.tests.connect.test_connect_basic",
+        "pyspark.sql.tests.connect.test_connect_error",
         "pyspark.sql.tests.connect.test_connect_function",
         "pyspark.sql.tests.connect.test_connect_collection",
         "pyspark.sql.tests.connect.test_connect_column",
         "pyspark.sql.tests.connect.test_connect_creation",
         "pyspark.sql.tests.connect.test_connect_readwriter",
         "pyspark.sql.tests.connect.test_connect_session",
+        "pyspark.sql.tests.connect.test_connect_stat",
         "pyspark.sql.tests.connect.test_parity_arrow",
         "pyspark.sql.tests.connect.test_parity_arrow_python_udf",
         "pyspark.sql.tests.connect.test_parity_datasources",
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 108201e5b927..4776851ba73d 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -20,12 +20,7 @@ import unittest
 import shutil
 import tempfile
 
-from pyspark.errors import (
-    PySparkAttributeError,
-    PySparkTypeError,
-    PySparkValueError,
-)
-from pyspark.errors.exceptions.base import SessionNotSameException
+from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession, Row
 from pyspark.sql.types import (
     StructType,
@@ -37,13 +32,7 @@ from pyspark.sql.types import (
     ArrayType,
     Row,
 )
-
-from pyspark.testing.sqlutils import (
-    SQLTestUtils,
-    PythonOnlyUDT,
-    ExamplePoint,
-    PythonOnlyPoint,
-)
+from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.testing.connectutils import (
     should_test_connect,
     ReusedConnectTestCase,
@@ -56,7 +45,6 @@ from pyspark.errors.exceptions.connect import (
 
 if should_test_connect:
     from pyspark.sql.connect.proto import Expression as ProtoExpression
-
     from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
     from pyspark.sql.connect.column import Column
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
@@ -142,19 +130,6 @@ class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSpark
 class SparkConnectBasicTests(SparkConnectSQLTestCase):
-    def test_recursion_handling_for_plan_logging(self):
-        """SPARK-45852 - Test that we can handle recursion in plan logging."""
-        cdf = self.connect.range(1)
-        for x in range(400):
-            cdf = cdf.withColumn(f"col_{x}", CF.lit(x))
-
-        # Calling schema will trigger logging the message that will in turn trigger the message
-        # conversion into protobuf that will then trigger the recursion error.
-        self.assertIsNotNone(cdf.schema)
-
-        result = self.connect._client._proto_to_string(cdf._plan.to_proto(self.connect._client))
-        self.assertIn("recursion", result)
-
     def test_df_getattr_behavior(self):
         cdf = self.connect.range(10)
         sdf = self.spark.range(10)
@@ -259,12 +234,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             },
         )
 
-    def test_error_handling(self):
-        # SPARK-41533 Proper error handling for Spark Connect
-        df = self.connect.range(10).select("id2")
-        with self.assertRaises(AnalysisException):
-            df.collect()
-
     def test_join_condition_column_list_columns(self):
         left_connect_df = self.connect.read.table(self.tbl_name)
         right_connect_df = self.connect.read.table(self.tbl_name2)
@@ -351,73 +320,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(sdf3.schema, cdf3.schema)
         self.assertEqual(sdf3.collect(), cdf3.collect())
 
-    def test_invalid_column(self):
-        # SPARK-41812: fail df1.select(df2.col)
-        data1 = [Row(a=1, b=2, c=3)]
-        cdf1 = self.connect.createDataFrame(data1)
-
-        data2 = [Row(a=2, b=0)]
-        cdf2 = self.connect.createDataFrame(data2)
-
-        with self.assertRaises(AnalysisException):
-            cdf1.select(cdf2.a).schema
-
-        with self.assertRaises(AnalysisException):
-            cdf2.withColumn("x", cdf1.a + 1).schema
-
-        # Can find the target plan node, but fail to resolve with it
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "UNRESOLVED_COLUMN.WITH_SUGGESTION",
-        ):
-            cdf3 = cdf1.select(cdf1.a)
-            cdf3.select(cdf1.b).schema
-
-        # Can not find the target plan node by plan id
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
-        ):
-            cdf1.select(cdf2.a).schema
-
-    def test_invalid_star(self):
-        data1 = [Row(a=1, b=2, c=3)]
-        cdf1 = self.connect.createDataFrame(data1)
-
-        data2 = [Row(a=2, b=0)]
-        cdf2 = self.connect.createDataFrame(data2)
-
-        # Can find the target plan node, but fail to resolve with it
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
-        ):
-            cdf3 = cdf1.select(cdf1.a)
-            cdf3.select(cdf1["*"]).schema
-
-        # Can find the target plan node, but fail to resolve with it
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
-        ):
-            # column 'a has been replaced
-            cdf3 = cdf1.withColumn("a", CF.lit(0))
-            cdf3.select(cdf1["*"]).schema
-
-        # Can not find the target plan node by plan id
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
-        ):
-            cdf1.select(cdf2["*"]).schema
-
-        # cdf1["*"] exists on both side
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "AMBIGUOUS_COLUMN_REFERENCE",
-        ):
-            cdf1.join(cdf1).select(cdf1["*"]).schema
-
     def test_with_columns_renamed(self):
         # SPARK-41312: test DataFrame.withColumnsRenamed()
         self.assertEqual(
@@ -718,14 +620,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas()
         )
 
-    def test_deduplicate_within_watermark_in_batch(self):
-        df = self.connect.read.table(self.tbl_name)
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
-        ):
-            df.dropDuplicatesWithinWatermark().toPandas()
-
     def test_drop(self):
         # SPARK-41169: test drop
         query = """
@@ -899,139 +793,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             sdf.select("a", "b.*").collect(),
         )
 
-    def test_fill_na(self):
-        # SPARK-41128: Test fill na
-        query = """
-            SELECT * FROM VALUES
-            (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
-            AS tab(a, b, c)
-            """
-        # +-----+----+----+
-        # |    a|   b|   c|
-        # +-----+----+----+
-        # |false|   1|NULL|
-        # |false|NULL| 2.0|
-        # | NULL|   3| 3.0|
-        # +-----+----+----+
-
-        self.assert_eq(
-            self.connect.sql(query).fillna(True).toPandas(),
-            self.spark.sql(query).fillna(True).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).fillna(2).toPandas(),
-            self.spark.sql(query).fillna(2).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).fillna(2, ["a", "b"]).toPandas(),
-            self.spark.sql(query).fillna(2, ["a", "b"]).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
-            self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
-        )
-
-    def test_drop_na(self):
-        # SPARK-41148: Test drop na
-        query = """
-            SELECT * FROM VALUES
-            (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
-            AS tab(a, b, c)
-            """
-        # +-----+----+----+
-        # |    a|   b|   c|
-        # +-----+----+----+
-        # |false|   1|NULL|
-        # |false|NULL| 2.0|
-        # | NULL|   3| 3.0|
-        # +-----+----+----+
-
-        self.assert_eq(
-            self.connect.sql(query).dropna().toPandas(),
-            self.spark.sql(query).dropna().toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).na.drop(how="all", thresh=1).toPandas(),
-            self.spark.sql(query).na.drop(how="all", thresh=1).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
-            self.spark.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
-            self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
-        )
-
-    def test_replace(self):
-        # SPARK-41315: Test replace
-        query = """
-            SELECT * FROM VALUES
-            (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
-            AS tab(a, b, c)
-            """
-        # +-----+----+----+
-        # |    a|   b|   c|
-        # +-----+----+----+
-        # |false|   1|NULL|
-        # |false|NULL| 2.0|
-        # | NULL|   3| 3.0|
-        # +-----+----+----+
-
-        self.assert_eq(
-            self.connect.sql(query).replace(2, 3).toPandas(),
-            self.spark.sql(query).replace(2, 3).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).na.replace(False, True).toPandas(),
-            self.spark.sql(query).na.replace(False, True).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
-            self.spark.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
-            self.spark.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
-        )
-        self.assert_eq(
-            self.connect.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
-            self.spark.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
-        )
-
-        with self.assertRaises(ValueError) as context:
-            self.connect.sql(query).replace({None: 1}, subset="a").toPandas()
-        self.assertTrue("Mixed type replacements are not supported" in str(context.exception))
-
-        with self.assertRaises(AnalysisException) as context:
-            self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas()
-        self.assertIn(
-            """Cannot resolve column name "x" among (a, b, c)""", str(context.exception)
-        )
-
-    def test_unpivot(self):
-        self.assert_eq(
-            self.connect.read.table(self.tbl_name)
-            .filter("id > 3")
-            .unpivot(["id"], ["name"], "variable", "value")
-            .toPandas(),
-            self.spark.read.table(self.tbl_name)
-            .filter("id > 3")
-            .unpivot(["id"], ["name"], "variable", "value")
-            .toPandas(),
-        )
-
-        self.assert_eq(
-            self.connect.read.table(self.tbl_name)
-            .filter("id > 3")
-            .unpivot("id", None, "variable", "value")
-            .toPandas(),
-            self.spark.read.table(self.tbl_name)
-            .filter("id > 3")
-            .unpivot("id", None, "variable", "value")
-            .toPandas(),
-        )
-
     def test_union_by_name(self):
         # SPARK-41832: Test unionByName
         data1 = [(1, 2, 3)]
@@ -1054,21 +815,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
 
         self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
 
-    def test_random_split(self):
-        # SPARK-41440: test randomSplit(weights, seed).
-        relations = (
-            self.connect.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
-        )
-        datasets = (
-            self.spark.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
-        )
-
-        self.assertTrue(len(relations) == len(datasets))
-        i = 0
-        while i < len(relations):
-            self.assert_eq(relations[i].toPandas(), datasets[i].toPandas())
-            i += 1
-
     def test_observe(self):
         # SPARK-41527: test DataFrame.observe()
         observation_name = "my_metric"
@@ -1207,36 +953,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             "ShuffledHashJoin" in cdf1.join(cdf2.hint("SHUFFLE_HASH"), "name")._explain_string()
         )
 
-    def test_different_spark_session_join_or_union(self):
-        df = self.connect.range(10).limit(3)
-
-        spark2 = RemoteSparkSession(connection="sc://localhost")
-        df2 = spark2.range(10).limit(3)
-
-        with self.assertRaises(SessionNotSameException) as e1:
-            df.union(df2).collect()
-        self.check_error(
-            exception=e1.exception,
-            error_class="SESSION_NOT_SAME",
-            message_parameters={},
-        )
-
-        with self.assertRaises(SessionNotSameException) as e2:
-            df.unionByName(df2).collect()
-        self.check_error(
-            exception=e2.exception,
-            error_class="SESSION_NOT_SAME",
-            message_parameters={},
-        )
-
-        with self.assertRaises(SessionNotSameException) as e3:
-            df.join(df2).collect()
-        self.check_error(
-            exception=e3.exception,
-            error_class="SESSION_NOT_SAME",
-            message_parameters={},
-        )
-
     def test_extended_hint_types(self):
         cdf = self.connect.range(100).toDF("id")
 
@@ -1301,209 +1017,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         expected = "+---+---+\n|  X|  Y|\n+---+---+\n|  1|  2|\n+---+---+\n"
         self.assertEqual(show_str, expected)
 
-    def test_describe(self):
-        # SPARK-41403: Test the describe method
-        self.assert_eq(
-            self.connect.read.table(self.tbl_name).describe("id").toPandas(),
-            self.spark.read.table(self.tbl_name).describe("id").toPandas(),
-        )
-        self.assert_eq(
-            self.connect.read.table(self.tbl_name).describe("id", "name").toPandas(),
-            self.spark.read.table(self.tbl_name).describe("id", "name").toPandas(),
-        )
-        self.assert_eq(
-            self.connect.read.table(self.tbl_name).describe(["id", "name"]).toPandas(),
-            self.spark.read.table(self.tbl_name).describe(["id", "name"]).toPandas(),
-        )
-
-    def test_stat_cov(self):
-        # SPARK-41067: Test the stat.cov method
-        self.assertEqual(
-            self.connect.read.table(self.tbl_name2).stat.cov("col1", "col3"),
-            self.spark.read.table(self.tbl_name2).stat.cov("col1", "col3"),
-        )
-
-    def test_stat_corr(self):
-        # SPARK-41068: Test the stat.corr method
-        self.assertEqual(
-            self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3"),
-            self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3"),
-        )
-
-        self.assertEqual(
-            self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
-            self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
-        )
-
-        with self.assertRaises(PySparkTypeError) as pe:
-            self.connect.read.table(self.tbl_name2).stat.corr(1, "col3", "pearson")
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_STR",
-            message_parameters={
-                "arg_name": "col1",
-                "arg_type": "int",
-            },
-        )
-
-        with self.assertRaises(PySparkTypeError) as pe:
-            self.connect.read.table(self.tbl_name).stat.corr("col1", 1, "pearson")
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_STR",
-            message_parameters={
-                "arg_name": "col2",
-                "arg_type": "int",
-            },
-        )
-        with self.assertRaises(ValueError) as context:
-            self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "spearman"),
-        self.assertTrue(
-            "Currently only the calculation of the Pearson Correlation "
-            + "coefficient is supported."
-            in str(context.exception)
-        )
-
-    def test_stat_approx_quantile(self):
-        # SPARK-41069: Test the stat.approxQuantile method
-        result = self.connect.read.table(self.tbl_name2).stat.approxQuantile(
-            ["col1", "col3"], [0.1, 0.5, 0.9], 0.1
-        )
-        self.assertEqual(len(result), 2)
-        self.assertEqual(len(result[0]), 3)
-        self.assertEqual(len(result[1]), 3)
-
-        result = self.connect.read.table(self.tbl_name2).stat.approxQuantile(
-            ["col1"], [0.1, 0.5, 0.9], 0.1
-        )
-        self.assertEqual(len(result), 1)
-        self.assertEqual(len(result[0]), 3)
-
-        with self.assertRaises(PySparkTypeError) as pe:
-            self.connect.read.table(self.tbl_name2).stat.approxQuantile(1, [0.1, 0.5, 0.9], 0.1)
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_LIST_OR_STR_OR_TUPLE",
-            message_parameters={
-                "arg_name": "col",
-                "arg_type": "int",
-            },
-        )
-
-        with self.assertRaises(PySparkTypeError) as pe:
-            self.connect.read.table(self.tbl_name2).stat.approxQuantile(["col1", "col3"], 0.1, 0.1)
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_LIST_OR_TUPLE",
-            message_parameters={
-                "arg_name": "probabilities",
-                "arg_type": "float",
-            },
-        )
-        with self.assertRaises(PySparkTypeError) as pe:
-            self.connect.read.table(self.tbl_name2).stat.approxQuantile(
-                ["col1", "col3"], [-0.1], 0.1
-            )
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_LIST_OF_FLOAT_OR_INT",
-            message_parameters={"arg_name": "probabilities", "arg_type": "float"},
-        )
-        with self.assertRaises(PySparkTypeError) as pe:
-            self.connect.read.table(self.tbl_name2).stat.approxQuantile(
-                ["col1", "col3"], [0.1, 0.5, 0.9], "str"
-            )
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_FLOAT_OR_INT",
-            message_parameters={
-                "arg_name": "relativeError",
-                "arg_type": "str",
-            },
-        )
-        with self.assertRaises(PySparkValueError) as pe:
-            self.connect.read.table(self.tbl_name2).stat.approxQuantile(
-                ["col1", "col3"], [0.1, 0.5, 0.9], -0.1
-            )
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NEGATIVE_VALUE",
-            message_parameters={
-                "arg_name": "relativeError",
-                "arg_value": "-0.1",
-            },
-        )
-
-    def test_stat_freq_items(self):
-        # SPARK-41065: Test the stat.freqItems method
-        self.assert_eq(
-            self.connect.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
-            self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
-            check_exact=False,
-        )
-
-        self.assert_eq(
-            self.connect.read.table(self.tbl_name2)
-            .stat.freqItems(["col1", "col3"], 0.4)
-            .toPandas(),
-            self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"], 0.4).toPandas(),
-        )
-
-        with self.assertRaises(PySparkTypeError) as pe:
-            self.connect.read.table(self.tbl_name2).stat.freqItems("col1")
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_LIST_OR_TUPLE",
-            message_parameters={
-                "arg_name": "cols",
-                "arg_type": "str",
-            },
-        )
-
-    def test_stat_sample_by(self):
-        # SPARK-41069: Test stat.sample_by
-
-        cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key"))
-        sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key"))
-
-        self.assert_eq(
-            cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
-            .groupBy("key")
-            .agg(CF.count(CF.lit(1)))
-            .orderBy("key")
-            .toPandas(),
-            sdf.sampleBy(sdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
-            .groupBy("key")
-            .agg(SF.count(SF.lit(1)))
-            .orderBy("key")
-            .toPandas(),
-        )
-
-        with self.assertRaises(PySparkTypeError) as pe:
-            cdf.stat.sampleBy(cdf.key, fractions={0: 0.1, None: 0.2}, seed=0)
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="DISALLOWED_TYPE_FOR_CONTAINER",
-            message_parameters={
-                "arg_name": "fractions",
-                "arg_type": "dict",
-                "allowed_types": "float, int, str",
-                "item_type": "NoneType",
-            },
-        )
-
-        with self.assertRaises(SparkConnectException):
-            cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 1.2}, seed=0).show()
-
     def test_repr(self):
         # SPARK-41213: Test the __repr__ method
         query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
@@ -1625,57 +1138,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         with self.assertRaises(AnalysisException):
             self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas()
 
-    def test_agg_with_two_agg_exprs(self) -> None:
-        # SPARK-41230: test dataframe.agg()
-        self.assert_eq(
-            self.connect.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
-            self.spark.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
-        )
-
-    def test_subtract(self):
-        # SPARK-41453: test dataframe.subtract()
-        ndf1 = self.connect.read.table(self.tbl_name)
-        ndf2 = ndf1.filter("id > 3")
-        df1 = self.spark.read.table(self.tbl_name)
-        df2 = df1.filter("id > 3")
-
-        self.assert_eq(
-            ndf1.subtract(ndf2).toPandas(),
-            df1.subtract(df2).toPandas(),
-        )
-
-    def test_agg_with_avg(self):
-        # SPARK-41325: groupby.avg()
-        df = (
-            self.connect.range(10)
-            .groupBy((CF.col("id") % CF.lit(2)).alias("moded"))
-            .avg("id")
-            .sort("moded")
-        )
-        res = df.collect()
-        self.assertEqual(2, len(res))
-        self.assertEqual(4.0, res[0][1])
-        self.assertEqual(5.0, res[1][1])
-
-        # Additional GroupBy tests with 3 rows
-
-        df_a = self.connect.range(10).groupBy((CF.col("id") % CF.lit(3)).alias("moded"))
-        df_b = self.spark.range(10).groupBy((SF.col("id") % SF.lit(3)).alias("moded"))
-        self.assertEqual(
-            set(df_b.agg(SF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect())
-        )
-
-        # Dict agg
-        measures = {"id": "sum"}
-        self.assertEqual(
-            set(df_a.agg(measures).select("sum(id)").collect()),
-            set(df_b.agg(measures).select("sum(id)").collect()),
-        )
-
-    def test_column_cannot_be_constructed_from_string(self):
-        with self.assertRaises(TypeError):
-            Column("col")
-
     def test_crossjoin(self):
         # SPARK-41227: Test CrossJoin
         connect_df = self.connect.read.table(self.tbl_name)
@@ -1693,376 +1155,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()),
         )
 
-    def test_grouped_data(self):
-        query = """
-            SELECT * FROM VALUES
-                ('James', 'Sales', 3000, 2020),
-                ('Michael', 'Sales', 4600, 2020),
-                ('Robert', 'Sales', 4100, 2020),
-                ('Maria', 'Finance', 3000, 2020),
-                ('James', 'Sales', 3000, 2019),
-                ('Scott', 'Finance', 3300, 2020),
-                ('Jen', 'Finance', 3900, 2020),
-                ('Jeff', 'Marketing', 3000, 2020),
-                ('Kumar', 'Marketing', 2000, 2020),
-                ('Saif', 'Sales', 4100, 2020)
-            AS T(name, department, salary, year)
-            """
-
-        # +-------+----------+------+----+
-        # |   name|department|salary|year|
-        # +-------+----------+------+----+
-        # |  James|     Sales|  3000|2020|
-        # |Michael|     Sales|  4600|2020|
-        # | Robert|     Sales|  4100|2020|
-        # |  Maria|   Finance|  3000|2020|
-        # |  James|     Sales|  3000|2019|
-        # |  Scott|   Finance|  3300|2020|
-        # |    Jen|   Finance|  3900|2020|
-        # |   Jeff| Marketing|  3000|2020|
-        # |  Kumar| Marketing|  2000|2020|
-        # |   Saif|     Sales|  4100|2020|
-        # +-------+----------+------+----+
-
-        cdf = self.connect.sql(query)
-        sdf = self.spark.sql(query)
-
-        # test groupby
-        self.assert_eq(
-            cdf.groupBy("name").agg(CF.sum(cdf.salary)).toPandas(),
-            sdf.groupBy("name").agg(SF.sum(sdf.salary)).toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
-            sdf.groupBy("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
-        )
-
-        # test rollup
-        self.assert_eq(
-            cdf.rollup("name").agg(CF.sum(cdf.salary)).toPandas(),
-            sdf.rollup("name").agg(SF.sum(sdf.salary)).toPandas(),
-        )
-        self.assert_eq(
-            cdf.rollup("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
-            sdf.rollup("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
-        )
-
-        # test cube
-        self.assert_eq(
-            cdf.cube("name").agg(CF.sum(cdf.salary)).toPandas(),
-            sdf.cube("name").agg(SF.sum(sdf.salary)).toPandas(),
-        )
-        self.assert_eq(
-            cdf.cube("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
-            sdf.cube("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
-        )
-
-        # test pivot
-        # pivot with values
-        self.assert_eq(
-            cdf.groupBy("name")
-            .pivot("department", ["Sales", "Marketing"])
-            .agg(CF.sum(cdf.salary))
-            .toPandas(),
-            sdf.groupBy("name")
-            .pivot("department", ["Sales", "Marketing"])
-            .agg(SF.sum(sdf.salary))
-            .toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy(cdf.name)
-            .pivot("department", ["Sales", "Finance", "Marketing"])
-            .agg(CF.sum(cdf.salary))
-            .toPandas(),
-            sdf.groupBy(sdf.name)
-            .pivot("department", ["Sales", "Finance", "Marketing"])
-            .agg(SF.sum(sdf.salary))
-            .toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy(cdf.name)
-            .pivot("department", ["Sales", "Finance", "Unknown"])
-            .agg(CF.sum(cdf.salary))
-            .toPandas(),
-            sdf.groupBy(sdf.name)
-            .pivot("department", ["Sales", "Finance", "Unknown"])
-            .agg(SF.sum(sdf.salary))
-            .toPandas(),
-        )
-
-        # pivot without values
-        self.assert_eq(
-            cdf.groupBy("name").pivot("department").agg(CF.sum(cdf.salary)).toPandas(),
-            sdf.groupBy("name").pivot("department").agg(SF.sum(sdf.salary)).toPandas(),
-        )
-
-        self.assert_eq(
-            cdf.groupBy("name").pivot("year").agg(CF.sum(cdf.salary)).toPandas(),
-            sdf.groupBy("name").pivot("year").agg(SF.sum(sdf.salary)).toPandas(),
-        )
-
-        # check error
-        with self.assertRaisesRegex(
-            Exception,
-            "PIVOT after ROLLUP is not supported",
-        ):
-            cdf.rollup("name").pivot("department").agg(CF.sum(cdf.salary))
-
-        with self.assertRaisesRegex(
-            Exception,
-            "PIVOT after CUBE is not supported",
-        ):
-            cdf.cube("name").pivot("department").agg(CF.sum(cdf.salary))
-
-        with self.assertRaisesRegex(
-            Exception,
-            "Repeated PIVOT operation is not supported",
-        ):
-            cdf.groupBy("name").pivot("year").pivot("year").agg(CF.sum(cdf.salary))
-
-        with self.assertRaises(PySparkTypeError) as pe:
-            cdf.groupBy("name").pivot("department", ["Sales", b"Marketing"]).agg(CF.sum(cdf.salary))
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_STR",
-            message_parameters={
-                "arg_name": "value",
-                "arg_type": "bytes",
-            },
-        )
-
-    def test_numeric_aggregation(self):
-        # SPARK-41737: test numeric aggregation
-        query = """
-            SELECT * FROM VALUES
-                ('James', 'Sales', 3000, 2020),
-                ('Michael', 'Sales', 4600, 2020),
-                ('Robert', 'Sales', 4100, 2020),
-                ('Maria', 'Finance', 3000, 2020),
-                ('James', 'Sales', 3000, 2019),
-                ('Scott', 'Finance', 3300, 2020),
-                ('Jen', 'Finance', 3900, 2020),
-                ('Jeff', 'Marketing', 3000, 2020),
-                ('Kumar', 'Marketing', 2000, 2020),
-                ('Saif', 'Sales', 4100, 2020)
-            AS T(name, department, salary, year)
-            """
-
-        # +-------+----------+------+----+
-        # |   name|department|salary|year|
-        # +-------+----------+------+----+
-        # |  James|     Sales|  3000|2020|
-        # |Michael|     Sales|  4600|2020|
-        # | Robert|     Sales|  4100|2020|
-        # |  Maria|   Finance|  3000|2020|
-        # |  James|     Sales|  3000|2019|
-        # |  Scott|   Finance|  3300|2020|
-        # |    Jen|   Finance|  3900|2020|
-        # |   Jeff| Marketing|  3000|2020|
-        # |  Kumar| Marketing|  2000|2020|
-        # |   Saif|     Sales|  4100|2020|
-        # +-------+----------+------+----+
-
-        cdf = self.connect.sql(query)
-        sdf = self.spark.sql(query)
-
-        # test groupby
-        self.assert_eq(
-            cdf.groupBy("name").min().toPandas(),
-            sdf.groupBy("name").min().toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name").min("salary").toPandas(),
-            sdf.groupBy("name").min("salary").toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name").max("salary").toPandas(),
-            sdf.groupBy("name").max("salary").toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name", cdf.department).avg("salary", "year").toPandas(),
-            sdf.groupBy("name", sdf.department).avg("salary", "year").toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name", cdf.department).mean("salary", "year").toPandas(),
-            sdf.groupBy("name", sdf.department).mean("salary", "year").toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name", cdf.department).sum("salary", "year").toPandas(),
-            sdf.groupBy("name", sdf.department).sum("salary", "year").toPandas(),
-        )
-
-        # test rollup
-        self.assert_eq(
-            cdf.rollup("name").max().toPandas(),
-            sdf.rollup("name").max().toPandas(),
-        )
-        self.assert_eq(
-            cdf.rollup("name").min("salary").toPandas(),
-            sdf.rollup("name").min("salary").toPandas(),
-        )
-        self.assert_eq(
-            cdf.rollup("name").max("salary").toPandas(),
-            sdf.rollup("name").max("salary").toPandas(),
-        )
-        self.assert_eq(
-            cdf.rollup("name", cdf.department).avg("salary", "year").toPandas(),
-            sdf.rollup("name", sdf.department).avg("salary", "year").toPandas(),
-        )
-        self.assert_eq(
-            cdf.rollup("name", cdf.department).mean("salary", "year").toPandas(),
-            sdf.rollup("name", sdf.department).mean("salary", "year").toPandas(),
-        )
-        self.assert_eq(
-            cdf.rollup("name", cdf.department).sum("salary", "year").toPandas(),
-            sdf.rollup("name", sdf.department).sum("salary", "year").toPandas(),
-        )
-
-        # test cube
-        self.assert_eq(
-            cdf.cube("name").avg().toPandas(),
-            sdf.cube("name").avg().toPandas(),
-        )
-        self.assert_eq(
-            cdf.cube("name").mean().toPandas(),
-            sdf.cube("name").mean().toPandas(),
-        )
-        self.assert_eq(
-            cdf.cube("name").min("salary").toPandas(),
-            sdf.cube("name").min("salary").toPandas(),
-        )
-        self.assert_eq(
-            cdf.cube("name").max("salary").toPandas(),
-            sdf.cube("name").max("salary").toPandas(),
-        )
-        self.assert_eq(
-            cdf.cube("name", cdf.department).avg("salary", "year").toPandas(),
-            sdf.cube("name", sdf.department).avg("salary", "year").toPandas(),
-        )
-        self.assert_eq(
-            cdf.cube("name", cdf.department).sum("salary", "year").toPandas(),
-            sdf.cube("name", sdf.department).sum("salary", "year").toPandas(),
-        )
-
-        # test pivot
-        # pivot with values
-        self.assert_eq(
-            cdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
-            sdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name")
-            .pivot("department", ["Sales", "Marketing"])
-            .min("salary")
-            .toPandas(),
-            sdf.groupBy("name")
-            .pivot("department", ["Sales", "Marketing"])
-            .min("salary")
-            .toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name")
-            .pivot("department", ["Sales", "Marketing"])
-            .max("salary")
-            .toPandas(),
-            sdf.groupBy("name")
-            .pivot("department", ["Sales", "Marketing"])
-            .max("salary")
-            .toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy(cdf.name)
-            .pivot("department", ["Sales", "Finance", "Unknown"])
-            .avg("salary", "year")
-            .toPandas(),
-            sdf.groupBy(sdf.name)
-            .pivot("department", ["Sales", "Finance", "Unknown"])
-            .avg("salary", "year")
-            .toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy(cdf.name)
-            .pivot("department", ["Sales", "Finance", "Unknown"])
-            .sum("salary", "year")
-            .toPandas(),
-            sdf.groupBy(sdf.name)
-            .pivot("department", ["Sales", "Finance", "Unknown"])
-            .sum("salary", "year")
-            .toPandas(),
-        )
-
-        # pivot without values
-        self.assert_eq(
-            cdf.groupBy("name").pivot("department").min().toPandas(),
-            sdf.groupBy("name").pivot("department").min().toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name").pivot("department").min("salary").toPandas(),
-            sdf.groupBy("name").pivot("department").min("salary").toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy("name").pivot("department").max("salary").toPandas(),
-            sdf.groupBy("name").pivot("department").max("salary").toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy(cdf.name).pivot("department").avg("salary", "year").toPandas(),
-            sdf.groupBy(sdf.name).pivot("department").avg("salary", "year").toPandas(),
-        )
-        self.assert_eq(
-            cdf.groupBy(cdf.name).pivot("department").sum("salary", "year").toPandas(),
-            sdf.groupBy(sdf.name).pivot("department").sum("salary", "year").toPandas(),
-        )
-
-        # check error
-        with self.assertRaisesRegex(
-            TypeError,
-            "Numeric aggregation function can only be applied on numeric columns",
-        ):
-            cdf.groupBy("name").min("department").show()
-
-        with self.assertRaisesRegex(
-            TypeError,
-            "Numeric aggregation function can only be applied on numeric columns",
-        ):
-            cdf.groupBy("name").max("salary", "department").show()
-
-        with self.assertRaisesRegex(
-            TypeError,
-            "Numeric aggregation function can only be applied on numeric columns",
-        ):
-            cdf.rollup("name").avg("department").show()
-
-        with self.assertRaisesRegex(
-            TypeError,
-            "Numeric aggregation function can only be applied on numeric columns",
-        ):
-            cdf.rollup("name").sum("salary", "department").show()
-
-        with self.assertRaisesRegex(
-            TypeError,
-            "Numeric aggregation function can only be applied on numeric columns",
-        ):
-            cdf.cube("name").min("department").show()
-
-        with self.assertRaisesRegex(
-            TypeError,
-            "Numeric aggregation function can only be applied on numeric columns",
-        ):
-            cdf.cube("name").max("salary", "department").show()
-
-        with self.assertRaisesRegex(
-            TypeError,
-            "Numeric aggregation function can only be applied on numeric columns",
-        ):
-            cdf.groupBy("name").pivot("department").avg("department").show()
-
-        with self.assertRaisesRegex(
-            TypeError,
-            "Numeric aggregation function can only be applied on numeric columns",
-        ):
-            cdf.groupBy("name").pivot("department").sum("salary", "department").show()
-
     def test_with_metadata(self):
         cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"])
         self.assertEqual(cdf.schema["age"].metadata, {})
@@ -2086,78 +1178,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             },
         )
 
-    def test_simple_udt(self):
-        from pyspark.ml.linalg import MatrixUDT, VectorUDT
-
-        for schema in [
-            StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
-            StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
-            StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT())),
-            StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
-            StructType().add("key", LongType()).add("vec", VectorUDT()),
-            StructType().add("key", LongType()).add("mat", MatrixUDT()),
-        ]:
-            cdf = self.connect.createDataFrame(data=[], schema=schema)
-            sdf = self.spark.createDataFrame(data=[], schema=schema)
-
-            self.assertEqual(cdf.schema, sdf.schema)
-
-    def test_simple_udt_from_read(self):
-        from pyspark.ml.linalg import Matrices, Vectors
-
-        with tempfile.TemporaryDirectory(prefix="test_simple_udt_from_read") as d:
-            path1 = f"{d}/df1.parquet"
-            self.spark.createDataFrame(
-                [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
-                schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
-            ).write.parquet(path1)
-
-            path2 = f"{d}/df2.parquet"
-            self.spark.createDataFrame(
-                [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
-                schema=StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
-            ).write.parquet(path2)
-
-            path3 = f"{d}/df3.parquet"
-            self.spark.createDataFrame(
-                [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
-                schema=StructType()
-                .add("key", LongType())
-                .add("val", MapType(LongType(), PythonOnlyUDT())),
-            ).write.parquet(path3)
-
-            path4 = f"{d}/df4.parquet"
-            self.spark.createDataFrame(
-                [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
-                schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
-            ).write.parquet(path4)
-
-            path5 = f"{d}/df5.parquet"
-            self.spark.createDataFrame(
-                [Row(label=1.0, point=ExamplePoint(1.0, 2.0))]
-            ).write.parquet(path5)
-
-            path6 = f"{d}/df6.parquet"
-            self.spark.createDataFrame(
-                [(Vectors.dense(1.0, 2.0, 3.0),), (Vectors.sparse(3, {1: 1.0, 2: 5.5}),)],
-                ["vec"],
-            ).write.parquet(path6)
-
-            path7 = f"{d}/df7.parquet"
-            self.spark.createDataFrame(
-                [
-                    (Matrices.dense(3, 2, [0, 1, 4, 5, 9, 10]),),
-                    (Matrices.sparse(1, 1, [0, 1], [0], [2.0]),),
-                ],
-                ["mat"],
-            ).write.parquet(path7)
-
-            for path in [path1, path2, path3, path4, path5, path6, path7]:
-                self.assertEqual(
-                    self.connect.read.parquet(path).schema,
-                    self.spark.read.parquet(path).schema,
-                )
-
     def test_version(self):
         self.assertEqual(
             self.connect.version,
@@ -2177,69 +1197,12 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             other.semanticHash(),
         )
 
-    def test_unsupported_functions(self):
-        # SPARK-41225: Disable unsupported functions.
-        df = self.connect.read.table(self.tbl_name)
-        for f in (
-            "checkpoint",
-            "localCheckpoint",
-        ):
-            with self.assertRaises(NotImplementedError):
-                getattr(df, f)()
-
     def test_sql_with_command(self):
         # SPARK-42705: spark.sql should return values from the command.
         self.assertEqual(
             self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect()
         )
 
-    def test_unsupported_jvm_attribute(self):
-        # Unsupported jvm attributes for Spark session.
-        unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
-        spark_session = self.connect
-        for attr in unsupported_attrs:
-            with self.assertRaises(PySparkAttributeError) as pe:
-                getattr(spark_session, attr)
-
-            self.check_error(
-                exception=pe.exception,
-                error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
-                message_parameters={"attr_name": attr},
-            )
-
-        # Unsupported jvm attributes for DataFrame.
-        unsupported_attrs = ["_jseq", "_jdf", "_jmap", "_jcols"]
-        cdf = self.connect.range(10)
-        for attr in unsupported_attrs:
-            with self.assertRaises(PySparkAttributeError) as pe:
-                getattr(cdf, attr)
-
-            self.check_error(
-                exception=pe.exception,
-                error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
-                message_parameters={"attr_name": attr},
-            )
-
-        # Unsupported jvm attributes for Column.
-        with self.assertRaises(PySparkAttributeError) as pe:
-            getattr(cdf.id, "_jc")
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
-            message_parameters={"attr_name": "_jc"},
-        )
-
-        # Unsupported jvm attributes for DataFrameReader.
-        with self.assertRaises(PySparkAttributeError) as pe:
-            getattr(spark_session.read, "_jreader")
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
-            message_parameters={"attr_name": "_jreader"},
-        )
-
     def test_df_caache(self):
         df = self.connect.range(10)
         df.cache()
diff --git a/python/pyspark/sql/tests/connect/test_connect_creation.py b/python/pyspark/sql/tests/connect/test_connect_creation.py
index 602453bd71d9..118e11161b15 100644
--- a/python/pyspark/sql/tests/connect/test_connect_creation.py
+++ b/python/pyspark/sql/tests/connect/test_connect_creation.py
@@ -27,12 +27,13 @@ from pyspark.sql.types import (
     StructField,
     StringType,
     IntegerType,
+    LongType,
     MapType,
     ArrayType,
     Row,
 )
+from pyspark.testing.sqlutils import MyObject, PythonOnlyUDT
 
-from pyspark.testing.sqlutils import MyObject
 from pyspark.testing.connectutils import should_test_connect
 from pyspark.errors.exceptions.connect import ParseException
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
@@ -689,6 +690,22 @@ class SparkConnectCreationTests(SparkConnectSQLTestCase):
         rows = [cols] * row_count
         self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count())
 
+    def test_simple_udt(self):
+        from pyspark.ml.linalg import MatrixUDT, VectorUDT
+
+        for schema in [
+            StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
+            StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
+            StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT())),
+            StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
+            StructType().add("key", LongType()).add("vec", VectorUDT()),
+            StructType().add("key", LongType()).add("mat", MatrixUDT()),
+        ]:
+            cdf = self.connect.createDataFrame(data=[], schema=schema)
+            sdf = self.spark.createDataFrame(data=[], schema=schema)
+
+            self.assertEqual(cdf.schema, sdf.schema)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_creation import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/connect/test_connect_error.py b/python/pyspark/sql/tests/connect/test_connect_error.py
new file mode 100644
index 000000000000..1297e62bb96f
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_connect_error.py
@@ -0,0 +1,230 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.errors import PySparkAttributeError
+from pyspark.errors.exceptions.base import SessionNotSameException
+from pyspark.sql.types import Row
+from pyspark.testing.connectutils import should_test_connect
+from pyspark.errors.exceptions.connect import AnalysisException
+from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+
+if should_test_connect:
+    from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
+    from pyspark.sql.connect import functions as CF
+    from pyspark.sql.connect.column import Column
+
+
+class SparkConnectErrorTests(SparkConnectSQLTestCase):
+    def test_recursion_handling_for_plan_logging(self):
+        """SPARK-45852 - Test that we can handle recursion in plan logging."""
+        cdf = self.connect.range(1)
+        for x in range(400):
+            cdf = cdf.withColumn(f"col_{x}", CF.lit(x))
+
+        # Calling schema will trigger logging the message that will in turn trigger the message
+        # conversion into protobuf that will then trigger the recursion error.
+        self.assertIsNotNone(cdf.schema)
+
+        result = self.connect._client._proto_to_string(cdf._plan.to_proto(self.connect._client))
+        self.assertIn("recursion", result)
+
+    def test_error_handling(self):
+        # SPARK-41533 Proper error handling for Spark Connect
+        df = self.connect.range(10).select("id2")
+        with self.assertRaises(AnalysisException):
+            df.collect()
+
+    def test_invalid_column(self):
+        # SPARK-41812: fail df1.select(df2.col)
+        data1 = [Row(a=1, b=2, c=3)]
+        cdf1 = self.connect.createDataFrame(data1)
+
+        data2 = [Row(a=2, b=0)]
+        cdf2 = self.connect.createDataFrame(data2)
+
+        with self.assertRaises(AnalysisException):
+            cdf1.select(cdf2.a).schema
+
+        with self.assertRaises(AnalysisException):
+            cdf2.withColumn("x", cdf1.a + 1).schema
+
+        # Can find the target plan node, but fail to resolve with it
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+        ):
+            cdf3 = cdf1.select(cdf1.a)
+            cdf3.select(cdf1.b).schema
+
+        # Can not find the target plan node by plan id
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            cdf1.select(cdf2.a).schema
+
+    def test_invalid_star(self):
+        data1 = [Row(a=1, b=2, c=3)]
+        cdf1 = self.connect.createDataFrame(data1)
+
+        data2 = [Row(a=2, b=0)]
+        cdf2 = self.connect.createDataFrame(data2)
+
+        # Can find the target plan node, but fail to resolve with it
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            cdf3 = cdf1.select(cdf1.a)
+            cdf3.select(cdf1["*"]).schema
+
+        # Can find the target plan node, but fail to resolve with it
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            # column 'a has been replaced
+            cdf3 = cdf1.withColumn("a", CF.lit(0))
+            cdf3.select(cdf1["*"]).schema
+
+        # Can not find the target plan node by plan id
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            cdf1.select(cdf2["*"]).schema
+
+        # cdf1["*"] exists on both side
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "AMBIGUOUS_COLUMN_REFERENCE",
+        ):
+            cdf1.join(cdf1).select(cdf1["*"]).schema
+
+    def test_deduplicate_within_watermark_in_batch(self):
+        df = self.connect.read.table(self.tbl_name)
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
+        ):
+            df.dropDuplicatesWithinWatermark().toPandas()
+
+    def test_different_spark_session_join_or_union(self):
+        df = self.connect.range(10).limit(3)
+
+        spark2 = RemoteSparkSession(connection="sc://localhost")
+        df2 = spark2.range(10).limit(3)
+
+        with self.assertRaises(SessionNotSameException) as e1:
+            df.union(df2).collect()
+        self.check_error(
+            exception=e1.exception,
+            error_class="SESSION_NOT_SAME",
+            message_parameters={},
+        )
+
+        with self.assertRaises(SessionNotSameException) as e2:
+            df.unionByName(df2).collect()
+        self.check_error(
+            exception=e2.exception,
+            error_class="SESSION_NOT_SAME",
+            message_parameters={},
+        )
+
+        with self.assertRaises(SessionNotSameException) as e3:
+            df.join(df2).collect()
+        self.check_error(
+            exception=e3.exception,
+            error_class="SESSION_NOT_SAME",
+            message_parameters={},
+        )
+
+    def test_unsupported_functions(self):
+        # SPARK-41225: Disable unsupported functions.
+        df = self.connect.read.table(self.tbl_name)
+        for f in (
+            "checkpoint",
+            "localCheckpoint",
+        ):
+            with self.assertRaises(NotImplementedError):
+                getattr(df, f)()
+
+    def test_unsupported_jvm_attribute(self):
+        # Unsupported jvm attributes for Spark session.
+        unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
+        spark_session = self.connect
+        for attr in unsupported_attrs:
+            with self.assertRaises(PySparkAttributeError) as pe:
+                getattr(spark_session, attr)
+
+            self.check_error(
+                exception=pe.exception,
+                error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
+                message_parameters={"attr_name": attr},
+            )
+
+        # Unsupported jvm attributes for DataFrame.
+        unsupported_attrs = ["_jseq", "_jdf", "_jmap", "_jcols"]
+        cdf = self.connect.range(10)
+        for attr in unsupported_attrs:
+            with self.assertRaises(PySparkAttributeError) as pe:
+                getattr(cdf, attr)
+
+            self.check_error(
+                exception=pe.exception,
+                error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
+                message_parameters={"attr_name": attr},
+            )
+
+        # Unsupported jvm attributes for Column.
+        with self.assertRaises(PySparkAttributeError) as pe:
+            getattr(cdf.id, "_jc")
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
+            message_parameters={"attr_name": "_jc"},
+        )
+
+        # Unsupported jvm attributes for DataFrameReader.
+ with self.assertRaises(PySparkAttributeError) as pe: + getattr(spark_session.read, "_jreader") + + self.check_error( + exception=pe.exception, + error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", + message_parameters={"attr_name": "_jreader"}, + ) + + def test_column_cannot_be_constructed_from_string(self): + with self.assertRaises(TypeError): + Column("col") + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_connect_error import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py index 6554fc961d79..db1e94cb6863 100644 --- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py +++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py @@ -26,8 +26,15 @@ from pyspark.sql.types import ( LongType, StringType, IntegerType, + ArrayType, + MapType, Row, ) +from pyspark.testing.sqlutils import ( + PythonOnlyUDT, + ExamplePoint, + PythonOnlyPoint, +) from pyspark.testing.connectutils import should_test_connect from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase @@ -271,6 +278,62 @@ class SparkConnectReadWriterTests(SparkConnectSQLTestCase): self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), DataFrameWriterV2) self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours("ts")), DataFrameWriterV2) + def test_simple_udt_from_read(self): + from pyspark.ml.linalg import Matrices, Vectors + + with tempfile.TemporaryDirectory(prefix="test_simple_udt_from_read") as d: + path1 = f"{d}/df1.parquet" + self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()), + ).write.parquet(path1) + + path2 = f"{d}/df2.parquet" + self.spark.createDataFrame( + [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + schema=StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())), + ).write.parquet(path2) + + path3 = f"{d}/df3.parquet" + self.spark.createDataFrame( + [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], + schema=StructType() + .add("key", LongType()) + .add("val", MapType(LongType(), PythonOnlyUDT())), + ).write.parquet(path3) + + path4 = f"{d}/df4.parquet" + self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()), + ).write.parquet(path4) + + path5 = f"{d}/df5.parquet" + self.spark.createDataFrame( + [Row(label=1.0, point=ExamplePoint(1.0, 2.0))] + ).write.parquet(path5) + + path6 = f"{d}/df6.parquet" + self.spark.createDataFrame( + [(Vectors.dense(1.0, 2.0, 3.0),), (Vectors.sparse(3, {1: 1.0, 2: 5.5}),)], + ["vec"], + ).write.parquet(path6) + + path7 = f"{d}/df7.parquet" + self.spark.createDataFrame( + [ + (Matrices.dense(3, 2, [0, 1, 4, 5, 9, 10]),), + (Matrices.sparse(1, 1, [0, 1], [0], [2.0]),), + ], + ["mat"], + ).write.parquet(path7) + + for path in [path1, path2, path3, path4, path5, path6, path7]: + self.assertEqual( + self.connect.read.parquet(path).schema, + self.spark.read.parquet(path).schema, + ) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_readwriter import * # noqa: F401 diff --git 
a/python/pyspark/sql/tests/connect/test_connect_stat.py b/python/pyspark/sql/tests/connect/test_connect_stat.py new file mode 100644 index 000000000000..24165636202c --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_connect_stat.py @@ -0,0 +1,813 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.errors import PySparkTypeError, PySparkValueError +from pyspark.testing.connectutils import should_test_connect +from pyspark.errors.exceptions.connect import ( + AnalysisException, + SparkConnectException, +) +from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase + +if should_test_connect: + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + +class SparkConnectStatTests(SparkConnectSQLTestCase): + def test_fill_na(self): + # SPARK-41128: Test fill na + query = """ + SELECT * FROM VALUES + (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0) + AS tab(a, b, c) + """ + # +-----+----+----+ + # | a| b| c| + # +-----+----+----+ + # |false| 1|NULL| + # |false|NULL| 2.0| + # | NULL| 3| 3.0| + # +-----+----+----+ + + self.assert_eq( + self.connect.sql(query).fillna(True).toPandas(), + self.spark.sql(query).fillna(True).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).fillna(2).toPandas(), + self.spark.sql(query).fillna(2).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).fillna(2, ["a", "b"]).toPandas(), + self.spark.sql(query).fillna(2, ["a", "b"]).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).na.fill({"a": True, "b": 2}).toPandas(), + self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(), + ) + + def test_drop_na(self): + # SPARK-41148: Test drop na + query = """ + SELECT * FROM VALUES + (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0) + AS tab(a, b, c) + """ + # +-----+----+----+ + # | a| b| c| + # +-----+----+----+ + # |false| 1|NULL| + # |false|NULL| 2.0| + # | NULL| 3| 3.0| + # +-----+----+----+ + + self.assert_eq( + self.connect.sql(query).dropna().toPandas(), + self.spark.sql(query).dropna().toPandas(), + ) + self.assert_eq( + self.connect.sql(query).na.drop(how="all", thresh=1).toPandas(), + self.spark.sql(query).na.drop(how="all", thresh=1).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(), + self.spark.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(), + self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(), + ) + + def test_replace(self): + # SPARK-41315: Test replace + query = """ + SELECT * FROM VALUES + (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0) + AS tab(a, b, c) + """ + # +-----+----+----+ + # | a| b| c| + # 
+-----+----+----+ + # |false| 1|NULL| + # |false|NULL| 2.0| + # | NULL| 3| 3.0| + # +-----+----+----+ + + self.assert_eq( + self.connect.sql(query).replace(2, 3).toPandas(), + self.spark.sql(query).replace(2, 3).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).na.replace(False, True).toPandas(), + self.spark.sql(query).na.replace(False, True).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(), + self.spark.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).na.replace((1, 2), (3, 1)).toPandas(), + self.spark.sql(query).na.replace((1, 2), (3, 1)).toPandas(), + ) + self.assert_eq( + self.connect.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(), + self.spark.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(), + ) + + with self.assertRaises(ValueError) as context: + self.connect.sql(query).replace({None: 1}, subset="a").toPandas() + self.assertTrue("Mixed type replacements are not supported" in str(context.exception)) + + with self.assertRaises(AnalysisException) as context: + self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas() + self.assertIn( + """Cannot resolve column name "x" among (a, b, c)""", str(context.exception) + ) + + def test_random_split(self): + # SPARK-41440: test randomSplit(weights, seed). + relations = ( + self.connect.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2) + ) + datasets = ( + self.spark.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2) + ) + + self.assertTrue(len(relations) == len(datasets)) + i = 0 + while i < len(relations): + self.assert_eq(relations[i].toPandas(), datasets[i].toPandas()) + i += 1 + + def test_describe(self): + # SPARK-41403: Test the describe method + self.assert_eq( + self.connect.read.table(self.tbl_name).describe("id").toPandas(), + self.spark.read.table(self.tbl_name).describe("id").toPandas(), + ) + self.assert_eq( + self.connect.read.table(self.tbl_name).describe("id", "name").toPandas(), + self.spark.read.table(self.tbl_name).describe("id", "name").toPandas(), + ) + self.assert_eq( + self.connect.read.table(self.tbl_name).describe(["id", "name"]).toPandas(), + self.spark.read.table(self.tbl_name).describe(["id", "name"]).toPandas(), + ) + + def test_stat_cov(self): + # SPARK-41067: Test the stat.cov method + self.assertEqual( + self.connect.read.table(self.tbl_name2).stat.cov("col1", "col3"), + self.spark.read.table(self.tbl_name2).stat.cov("col1", "col3"), + ) + + def test_stat_corr(self): + # SPARK-41068: Test the stat.corr method + self.assertEqual( + self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3"), + self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3"), + ) + + self.assertEqual( + self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"), + self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"), + ) + + with self.assertRaises(PySparkTypeError) as pe: + self.connect.read.table(self.tbl_name2).stat.corr(1, "col3", "pearson") + + self.check_error( + exception=pe.exception, + error_class="NOT_STR", + message_parameters={ + "arg_name": "col1", + "arg_type": "int", + }, + ) + + with self.assertRaises(PySparkTypeError) as pe: + self.connect.read.table(self.tbl_name).stat.corr("col1", 1, "pearson") + + self.check_error( + exception=pe.exception, + error_class="NOT_STR", + message_parameters={ + "arg_name": "col2", + "arg_type": 
"int", + }, + ) + with self.assertRaises(ValueError) as context: + self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "spearman"), + self.assertTrue( + "Currently only the calculation of the Pearson Correlation " + + "coefficient is supported." + in str(context.exception) + ) + + def test_stat_approx_quantile(self): + # SPARK-41069: Test the stat.approxQuantile method + result = self.connect.read.table(self.tbl_name2).stat.approxQuantile( + ["col1", "col3"], [0.1, 0.5, 0.9], 0.1 + ) + self.assertEqual(len(result), 2) + self.assertEqual(len(result[0]), 3) + self.assertEqual(len(result[1]), 3) + + result = self.connect.read.table(self.tbl_name2).stat.approxQuantile( + ["col1"], [0.1, 0.5, 0.9], 0.1 + ) + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 3) + + with self.assertRaises(PySparkTypeError) as pe: + self.connect.read.table(self.tbl_name2).stat.approxQuantile(1, [0.1, 0.5, 0.9], 0.1) + + self.check_error( + exception=pe.exception, + error_class="NOT_LIST_OR_STR_OR_TUPLE", + message_parameters={ + "arg_name": "col", + "arg_type": "int", + }, + ) + + with self.assertRaises(PySparkTypeError) as pe: + self.connect.read.table(self.tbl_name2).stat.approxQuantile(["col1", "col3"], 0.1, 0.1) + + self.check_error( + exception=pe.exception, + error_class="NOT_LIST_OR_TUPLE", + message_parameters={ + "arg_name": "probabilities", + "arg_type": "float", + }, + ) + with self.assertRaises(PySparkTypeError) as pe: + self.connect.read.table(self.tbl_name2).stat.approxQuantile( + ["col1", "col3"], [-0.1], 0.1 + ) + + self.check_error( + exception=pe.exception, + error_class="NOT_LIST_OF_FLOAT_OR_INT", + message_parameters={"arg_name": "probabilities", "arg_type": "float"}, + ) + with self.assertRaises(PySparkTypeError) as pe: + self.connect.read.table(self.tbl_name2).stat.approxQuantile( + ["col1", "col3"], [0.1, 0.5, 0.9], "str" + ) + + self.check_error( + exception=pe.exception, + error_class="NOT_FLOAT_OR_INT", + message_parameters={ + "arg_name": "relativeError", + "arg_type": "str", + }, + ) + with self.assertRaises(PySparkValueError) as pe: + self.connect.read.table(self.tbl_name2).stat.approxQuantile( + ["col1", "col3"], [0.1, 0.5, 0.9], -0.1 + ) + + self.check_error( + exception=pe.exception, + error_class="NEGATIVE_VALUE", + message_parameters={ + "arg_name": "relativeError", + "arg_value": "-0.1", + }, + ) + + def test_stat_freq_items(self): + # SPARK-41065: Test the stat.freqItems method + self.assert_eq( + self.connect.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(), + self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(), + check_exact=False, + ) + + self.assert_eq( + self.connect.read.table(self.tbl_name2) + .stat.freqItems(["col1", "col3"], 0.4) + .toPandas(), + self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"], 0.4).toPandas(), + ) + + with self.assertRaises(PySparkTypeError) as pe: + self.connect.read.table(self.tbl_name2).stat.freqItems("col1") + + self.check_error( + exception=pe.exception, + error_class="NOT_LIST_OR_TUPLE", + message_parameters={ + "arg_name": "cols", + "arg_type": "str", + }, + ) + + def test_stat_sample_by(self): + # SPARK-41069: Test stat.sample_by + + cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key")) + sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key")) + + self.assert_eq( + cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 0.2}, seed=0) + .groupBy("key") + .agg(CF.count(CF.lit(1))) + .orderBy("key") + .toPandas(), + 
+        with self.assertRaises(PySparkTypeError) as pe:
+            cdf.stat.sampleBy(cdf.key, fractions={0: 0.1, None: 0.2}, seed=0)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+            message_parameters={
+                "arg_name": "fractions",
+                "arg_type": "dict",
+                "allowed_types": "float, int, str",
+                "item_type": "NoneType",
+            },
+        )
+
+        with self.assertRaises(SparkConnectException):
+            cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 1.2}, seed=0).show()
+
+    def test_subtract(self):
+        # SPARK-41453: test dataframe.subtract()
+        ndf1 = self.connect.read.table(self.tbl_name)
+        ndf2 = ndf1.filter("id > 3")
+        df1 = self.spark.read.table(self.tbl_name)
+        df2 = df1.filter("id > 3")
+
+        self.assert_eq(
+            ndf1.subtract(ndf2).toPandas(),
+            df1.subtract(df2).toPandas(),
+        )
+
+    def test_agg_with_avg(self):
+        # SPARK-41325: groupby.avg()
+        df = (
+            self.connect.range(10)
+            .groupBy((CF.col("id") % CF.lit(2)).alias("moded"))
+            .avg("id")
+            .sort("moded")
+        )
+        res = df.collect()
+        self.assertEqual(2, len(res))
+        self.assertEqual(4.0, res[0][1])
+        self.assertEqual(5.0, res[1][1])
+
+        # Additional GroupBy tests with 3 rows
+
+        df_a = self.connect.range(10).groupBy((CF.col("id") % CF.lit(3)).alias("moded"))
+        df_b = self.spark.range(10).groupBy((SF.col("id") % SF.lit(3)).alias("moded"))
+        self.assertEqual(
+            set(df_b.agg(SF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect())
+        )
+
+        # Dict agg
+        measures = {"id": "sum"}
+        self.assertEqual(
+            set(df_a.agg(measures).select("sum(id)").collect()),
+            set(df_b.agg(measures).select("sum(id)").collect()),
+        )
+
+    def test_agg_with_two_agg_exprs(self) -> None:
+        # SPARK-41230: test dataframe.agg()
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
+            self.spark.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
+        )
+
+    def test_grouped_data(self):
+        # Test groupBy, rollup, cube and pivot over the same dataset
+        query = """
+            SELECT * FROM VALUES
+            ('James', 'Sales', 3000, 2020),
+            ('Michael', 'Sales', 4600, 2020),
+            ('Robert', 'Sales', 4100, 2020),
+            ('Maria', 'Finance', 3000, 2020),
+            ('James', 'Sales', 3000, 2019),
+            ('Scott', 'Finance', 3300, 2020),
+            ('Jen', 'Finance', 3900, 2020),
+            ('Jeff', 'Marketing', 3000, 2020),
+            ('Kumar', 'Marketing', 2000, 2020),
+            ('Saif', 'Sales', 4100, 2020)
+            AS T(name, department, salary, year)
+            """
+
+        # +-------+----------+------+----+
+        # |   name|department|salary|year|
+        # +-------+----------+------+----+
+        # |  James|     Sales|  3000|2020|
+        # |Michael|     Sales|  4600|2020|
+        # | Robert|     Sales|  4100|2020|
+        # |  Maria|   Finance|  3000|2020|
+        # |  James|     Sales|  3000|2019|
+        # |  Scott|   Finance|  3300|2020|
+        # |    Jen|   Finance|  3900|2020|
+        # |   Jeff| Marketing|  3000|2020|
+        # |  Kumar| Marketing|  2000|2020|
+        # |   Saif|     Sales|  4100|2020|
+        # +-------+----------+------+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # test groupby
+        self.assert_eq(
+            cdf.groupBy("name").agg(CF.sum(cdf.salary)).toPandas(),
+            sdf.groupBy("name").agg(SF.sum(sdf.salary)).toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
+            sdf.groupBy("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
+        )
+
+        # test rollup
+        self.assert_eq(
+            cdf.rollup("name").agg(CF.sum(cdf.salary)).toPandas(),
+            sdf.rollup("name").agg(SF.sum(sdf.salary)).toPandas(),
+        )
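+        # rollup also accepts multiple grouping columns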
cdf.rollup("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(), + sdf.rollup("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(), + ) + + # test cube + self.assert_eq( + cdf.cube("name").agg(CF.sum(cdf.salary)).toPandas(), + sdf.cube("name").agg(SF.sum(sdf.salary)).toPandas(), + ) + self.assert_eq( + cdf.cube("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(), + sdf.cube("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(), + ) + + # test pivot + # pivot with values + self.assert_eq( + cdf.groupBy("name") + .pivot("department", ["Sales", "Marketing"]) + .agg(CF.sum(cdf.salary)) + .toPandas(), + sdf.groupBy("name") + .pivot("department", ["Sales", "Marketing"]) + .agg(SF.sum(sdf.salary)) + .toPandas(), + ) + self.assert_eq( + cdf.groupBy(cdf.name) + .pivot("department", ["Sales", "Finance", "Marketing"]) + .agg(CF.sum(cdf.salary)) + .toPandas(), + sdf.groupBy(sdf.name) + .pivot("department", ["Sales", "Finance", "Marketing"]) + .agg(SF.sum(sdf.salary)) + .toPandas(), + ) + self.assert_eq( + cdf.groupBy(cdf.name) + .pivot("department", ["Sales", "Finance", "Unknown"]) + .agg(CF.sum(cdf.salary)) + .toPandas(), + sdf.groupBy(sdf.name) + .pivot("department", ["Sales", "Finance", "Unknown"]) + .agg(SF.sum(sdf.salary)) + .toPandas(), + ) + + # pivot without values + self.assert_eq( + cdf.groupBy("name").pivot("department").agg(CF.sum(cdf.salary)).toPandas(), + sdf.groupBy("name").pivot("department").agg(SF.sum(sdf.salary)).toPandas(), + ) + + self.assert_eq( + cdf.groupBy("name").pivot("year").agg(CF.sum(cdf.salary)).toPandas(), + sdf.groupBy("name").pivot("year").agg(SF.sum(sdf.salary)).toPandas(), + ) + + # check error + with self.assertRaisesRegex( + Exception, + "PIVOT after ROLLUP is not supported", + ): + cdf.rollup("name").pivot("department").agg(CF.sum(cdf.salary)) + + with self.assertRaisesRegex( + Exception, + "PIVOT after CUBE is not supported", + ): + cdf.cube("name").pivot("department").agg(CF.sum(cdf.salary)) + + with self.assertRaisesRegex( + Exception, + "Repeated PIVOT operation is not supported", + ): + cdf.groupBy("name").pivot("year").pivot("year").agg(CF.sum(cdf.salary)) + + with self.assertRaises(PySparkTypeError) as pe: + cdf.groupBy("name").pivot("department", ["Sales", b"Marketing"]).agg(CF.sum(cdf.salary)) + + self.check_error( + exception=pe.exception, + error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_STR", + message_parameters={ + "arg_name": "value", + "arg_type": "bytes", + }, + ) + + def test_numeric_aggregation(self): + # SPARK-41737: test numeric aggregation + query = """ + SELECT * FROM VALUES + ('James', 'Sales', 3000, 2020), + ('Michael', 'Sales', 4600, 2020), + ('Robert', 'Sales', 4100, 2020), + ('Maria', 'Finance', 3000, 2020), + ('James', 'Sales', 3000, 2019), + ('Scott', 'Finance', 3300, 2020), + ('Jen', 'Finance', 3900, 2020), + ('Jeff', 'Marketing', 3000, 2020), + ('Kumar', 'Marketing', 2000, 2020), + ('Saif', 'Sales', 4100, 2020) + AS T(name, department, salary, year) + """ + + # +-------+----------+------+----+ + # | name|department|salary|year| + # +-------+----------+------+----+ + # | James| Sales| 3000|2020| + # |Michael| Sales| 4600|2020| + # | Robert| Sales| 4100|2020| + # | Maria| Finance| 3000|2020| + # | James| Sales| 3000|2019| + # | Scott| Finance| 3300|2020| + # | Jen| Finance| 3900|2020| + # | Jeff| Marketing| 3000|2020| + # | Kumar| Marketing| 2000|2020| + # | Saif| Sales| 4100|2020| + # +-------+----------+------+----+ + + cdf = self.connect.sql(query) + 
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # test groupby
+        self.assert_eq(
+            cdf.groupBy("name").min().toPandas(),
+            sdf.groupBy("name").min().toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name").min("salary").toPandas(),
+            sdf.groupBy("name").min("salary").toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name").max("salary").toPandas(),
+            sdf.groupBy("name").max("salary").toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name", cdf.department).avg("salary", "year").toPandas(),
+            sdf.groupBy("name", sdf.department).avg("salary", "year").toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name", cdf.department).mean("salary", "year").toPandas(),
+            sdf.groupBy("name", sdf.department).mean("salary", "year").toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name", cdf.department).sum("salary", "year").toPandas(),
+            sdf.groupBy("name", sdf.department).sum("salary", "year").toPandas(),
+        )
+
+        # test rollup
+        self.assert_eq(
+            cdf.rollup("name").max().toPandas(),
+            sdf.rollup("name").max().toPandas(),
+        )
+        self.assert_eq(
+            cdf.rollup("name").min("salary").toPandas(),
+            sdf.rollup("name").min("salary").toPandas(),
+        )
+        self.assert_eq(
+            cdf.rollup("name").max("salary").toPandas(),
+            sdf.rollup("name").max("salary").toPandas(),
+        )
+        self.assert_eq(
+            cdf.rollup("name", cdf.department).avg("salary", "year").toPandas(),
+            sdf.rollup("name", sdf.department).avg("salary", "year").toPandas(),
+        )
+        self.assert_eq(
+            cdf.rollup("name", cdf.department).mean("salary", "year").toPandas(),
+            sdf.rollup("name", sdf.department).mean("salary", "year").toPandas(),
+        )
+        self.assert_eq(
+            cdf.rollup("name", cdf.department).sum("salary", "year").toPandas(),
+            sdf.rollup("name", sdf.department).sum("salary", "year").toPandas(),
+        )
+
+        # test cube
+        self.assert_eq(
+            cdf.cube("name").avg().toPandas(),
+            sdf.cube("name").avg().toPandas(),
+        )
+        self.assert_eq(
+            cdf.cube("name").mean().toPandas(),
+            sdf.cube("name").mean().toPandas(),
+        )
+        self.assert_eq(
+            cdf.cube("name").min("salary").toPandas(),
+            sdf.cube("name").min("salary").toPandas(),
+        )
+        self.assert_eq(
+            cdf.cube("name").max("salary").toPandas(),
+            sdf.cube("name").max("salary").toPandas(),
+        )
+        self.assert_eq(
+            cdf.cube("name", cdf.department).avg("salary", "year").toPandas(),
+            sdf.cube("name", sdf.department).avg("salary", "year").toPandas(),
+        )
+        self.assert_eq(
+            cdf.cube("name", cdf.department).sum("salary", "year").toPandas(),
+            sdf.cube("name", sdf.department).sum("salary", "year").toPandas(),
+        )
+
+        # test pivot
+        # pivot with values
+        self.assert_eq(
+            cdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
+            sdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name")
+            .pivot("department", ["Sales", "Marketing"])
+            .min("salary")
+            .toPandas(),
+            sdf.groupBy("name")
+            .pivot("department", ["Sales", "Marketing"])
+            .min("salary")
+            .toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name")
+            .pivot("department", ["Sales", "Marketing"])
+            .max("salary")
+            .toPandas(),
+            sdf.groupBy("name")
+            .pivot("department", ["Sales", "Marketing"])
+            .max("salary")
+            .toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy(cdf.name)
+            .pivot("department", ["Sales", "Finance", "Unknown"])
+            .avg("salary", "year")
+            .toPandas(),
+            sdf.groupBy(sdf.name)
+            .pivot("department", ["Sales", "Finance", "Unknown"])
+            .avg("salary", "year")
+            .toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy(cdf.name)
+            .pivot("department", ["Sales", "Finance", "Unknown"])
+            .sum("salary", "year")
+            .toPandas(),
+            sdf.groupBy(sdf.name)
+            .pivot("department", ["Sales", "Finance", "Unknown"])
+            .sum("salary", "year")
+            .toPandas(),
+        )
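+
+        # pivot without explicit values: Spark first computes the distinct
+        # values of the pivot column, which is more expensive than listing them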
+        # pivot without values
+        self.assert_eq(
+            cdf.groupBy("name").pivot("department").min().toPandas(),
+            sdf.groupBy("name").pivot("department").min().toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name").pivot("department").min("salary").toPandas(),
+            sdf.groupBy("name").pivot("department").min("salary").toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name").pivot("department").max("salary").toPandas(),
+            sdf.groupBy("name").pivot("department").max("salary").toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy(cdf.name).pivot("department").avg("salary", "year").toPandas(),
+            sdf.groupBy(sdf.name).pivot("department").avg("salary", "year").toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy(cdf.name).pivot("department").sum("salary", "year").toPandas(),
+            sdf.groupBy(sdf.name).pivot("department").sum("salary", "year").toPandas(),
+        )
+
+        # check error
+        with self.assertRaisesRegex(
+            TypeError,
+            "Numeric aggregation function can only be applied on numeric columns",
+        ):
+            cdf.groupBy("name").min("department").show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Numeric aggregation function can only be applied on numeric columns",
+        ):
+            cdf.groupBy("name").max("salary", "department").show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Numeric aggregation function can only be applied on numeric columns",
+        ):
+            cdf.rollup("name").avg("department").show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Numeric aggregation function can only be applied on numeric columns",
+        ):
+            cdf.rollup("name").sum("salary", "department").show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Numeric aggregation function can only be applied on numeric columns",
+        ):
+            cdf.cube("name").min("department").show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Numeric aggregation function can only be applied on numeric columns",
+        ):
+            cdf.cube("name").max("salary", "department").show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Numeric aggregation function can only be applied on numeric columns",
+        ):
+            cdf.groupBy("name").pivot("department").avg("department").show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Numeric aggregation function can only be applied on numeric columns",
+        ):
+            cdf.groupBy("name").pivot("department").sum("salary", "department").show()
+
+    def test_unpivot(self):
+        # Test unpivot with explicit value columns, and with values=None
+        # (unpivot all columns that are not ids)
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name)
+            .filter("id > 3")
+            .unpivot(["id"], ["name"], "variable", "value")
+            .toPandas(),
+            self.spark.read.table(self.tbl_name)
+            .filter("id > 3")
+            .unpivot(["id"], ["name"], "variable", "value")
+            .toPandas(),
+        )
+
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name)
+            .filter("id > 3")
+            .unpivot("id", None, "variable", "value")
+            .toPandas(),
+            self.spark.read.table(self.tbl_name)
+            .filter("id > 3")
+            .unpivot("id", None, "variable", "value")
+            .toPandas(),
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_connect_stat import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+
+    unittest.main(testRunner=testRunner, verbosity=2)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org