This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 01c7a46f24f [SPARK-40539][CONNECT] Initial DataFrame Read API parity for Spark Connect 01c7a46f24f is described below commit 01c7a46f24fb4bb4287a184a3d69e0e5c904bc50 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Thu Oct 20 09:26:06 2022 +0800 [SPARK-40539][CONNECT] Initial DataFrame Read API parity for Spark Connect ### What changes were proposed in this pull request? Add initial Read API for Spark Connect that allows setting schema, format, option and path, and then to read files into DataFrame. ### Why are the changes needed? PySpark readwriter API parity for Spark Connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38086 from amaliujia/SPARK-40539. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/protobuf/spark/connect/relations.proto | 10 ++ .../sql/connect/planner/SparkConnectPlanner.scala | 15 ++- .../connect/planner/SparkConnectPlannerSuite.scala | 11 ++ python/pyspark/sql/connect/plan.py | 41 +++++++ python/pyspark/sql/connect/proto/relations_pb2.py | 78 +++++++------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 56 +++++++++- python/pyspark/sql/connect/readwriter.py | 118 ++++++++++++++++++++- .../sql/tests/connect/test_connect_basic.py | 20 ++++ .../sql/tests/connect/test_connect_plan_only.py | 13 +++ 9 files changed, 320 insertions(+), 42 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index 353fbebd046..eadedf495d3 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -67,11 +67,21 @@ message SQL { message Read { oneof read_type { NamedTable named_table = 1; + DataSource data_source = 2; } message NamedTable { string unparsed_identifier = 
1; } + + message DataSource { + // Required. Supported formats include: parquet, orc, text, json, csv, avro. + string format = 1; + // Optional. If not set, Spark will infer the schema. + string schema = 2; + // The key is case insensitive. + map<string, string> options = 3; + } } // Projection of a bag of expressions for a given input relation. diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 6a6b5a15a08..450283a9b81 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeRef import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample, SubqueryAlias} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.types._ final case class InvalidPlanInput( @@ -112,7 +113,19 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { } else { child } - case _ => throw InvalidPlanInput() + case proto.Read.ReadTypeCase.DATA_SOURCE => + if (rel.getDataSource.getFormat == "") { + throw InvalidPlanInput("DataSource requires a format") + } + val localMap = CaseInsensitiveMap[String](rel.getDataSource.getOptionsMap.asScala.toMap) + val reader = session.read + reader.format(rel.getDataSource.getFormat) + localMap.foreach { case (key, value) => reader.option(key, value) } + if (!rel.getDataSource.getSchema.isEmpty) { // an unset proto string reads as "", never null; empty means infer + reader.schema(rel.getDataSource.getSchema) + } + reader.load().queryExecution.analyzed + case _ => throw 
InvalidPlanInput("Does not support " + rel.getReadTypeCase.name()) } baseRelation } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index fc3d219ec6b..83bf76efce1 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -255,4 +255,15 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { assert(res.nodeName == "Aggregate") } + test("Invalid DataSource") { + val dataSource = proto.Read.DataSource.newBuilder() + + val e = intercept[InvalidPlanInput]( + transform( + proto.Relation + .newBuilder() + .setRead(proto.Read.newBuilder().setDataSource(dataSource)) + .build())) + assert(e.getMessage.contains("DataSource requires a format")) + } } diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 5fcd468924d..c564b71cdba 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -23,6 +23,7 @@ from typing import ( Union, cast, TYPE_CHECKING, + Mapping, ) import pyspark.sql.connect.proto as proto @@ -111,6 +112,46 @@ class LogicalPlan(object): return self._child._repr_html_() if self._child is not None else "" +class DataSource(LogicalPlan): + """A data source with a format and an optional schema from which Spark reads data.""" + + def __init__( + self, + format: str = "", + schema: Optional[str] = None, + options: Optional[Mapping[str, str]] = None, + ) -> None: + super().__init__(None) + self.format = format + self.schema = schema + self.options = options + + def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: + plan = proto.Relation() + if self.format is not None: + plan.read.data_source.format = self.format + if 
self.schema is not None: + plan.read.data_source.schema = self.schema + if self.options is not None: + for k in self.options.keys(): + v = self.options.get(k) + if v is not None: + plan.read.data_source.options[k] = v + return plan + + def _repr_html_(self) -> str: + return f""" + <ul> + <li> + <b>DataSource</b><br /> + format: {self.format} + schema: {self.schema} + options: {self.options} + </li> + </ul> + """ + + class Read(LogicalPlan): def __init__(self, table_name: str) -> None: super().__init__(None) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index c3b7b7ec2ea..b244cdf8dcb 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcf\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcf\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -41,6 +41,8 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" + _READ_DATASOURCE_OPTIONSENTRY._options = None + _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 82 _RELATION._serialized_end = 801 _UNKNOWN._serialized_start = 803 @@ -50,39 +52,43 @@ if _descriptor._USE_C_DESCRIPTORS == False: _SQL._serialized_start = 887 _SQL._serialized_end = 914 _READ._serialized_start = 917 - _READ._serialized_end = 1066 - _READ_NAMEDTABLE._serialized_start = 992 - _READ_NAMEDTABLE._serialized_end = 1053 - _PROJECT._serialized_start = 1068 - _PROJECT._serialized_end = 1185 - _FILTER._serialized_start = 1187 - _FILTER._serialized_end = 1299 - _JOIN._serialized_start = 1302 - _JOIN._serialized_end = 1715 - _JOIN_JOINTYPE._serialized_start = 1528 - _JOIN_JOINTYPE._serialized_end = 1715 - _UNION._serialized_start = 1718 - _UNION._serialized_end = 1923 - _UNION_UNIONTYPE._serialized_start = 1839 - _UNION_UNIONTYPE._serialized_end = 1923 - _LIMIT._serialized_start = 1925 - _LIMIT._serialized_end = 2001 - _OFFSET._serialized_start = 2003 - _OFFSET._serialized_end = 2082 - _AGGREGATE._serialized_start = 2085 - _AGGREGATE._serialized_end = 2410 - _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2314 - _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2410 - _SORT._serialized_start = 2413 - _SORT._serialized_end = 2915 - _SORT_SORTFIELD._serialized_start = 2533 - _SORT_SORTFIELD._serialized_end = 2721 - _SORT_SORTDIRECTION._serialized_start = 2723 - _SORT_SORTDIRECTION._serialized_end = 2831 - _SORT_SORTNULLS._serialized_start = 2833 - _SORT_SORTNULLS._serialized_end = 2915 - _LOCALRELATION._serialized_start = 2917 - _LOCALRELATION._serialized_end = 3010 - _SAMPLE._serialized_start = 3013 - _SAMPLE._serialized_end = 3197 + _READ._serialized_end = 1327 + 
_READ_NAMEDTABLE._serialized_start = 1059 + _READ_NAMEDTABLE._serialized_end = 1120 + _READ_DATASOURCE._serialized_start = 1123 + _READ_DATASOURCE._serialized_end = 1314 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1256 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1314 + _PROJECT._serialized_start = 1329 + _PROJECT._serialized_end = 1446 + _FILTER._serialized_start = 1448 + _FILTER._serialized_end = 1560 + _JOIN._serialized_start = 1563 + _JOIN._serialized_end = 1976 + _JOIN_JOINTYPE._serialized_start = 1789 + _JOIN_JOINTYPE._serialized_end = 1976 + _UNION._serialized_start = 1979 + _UNION._serialized_end = 2184 + _UNION_UNIONTYPE._serialized_start = 2100 + _UNION_UNIONTYPE._serialized_end = 2184 + _LIMIT._serialized_start = 2186 + _LIMIT._serialized_end = 2262 + _OFFSET._serialized_start = 2264 + _OFFSET._serialized_end = 2343 + _AGGREGATE._serialized_start = 2346 + _AGGREGATE._serialized_end = 2671 + _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2575 + _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2671 + _SORT._serialized_start = 2674 + _SORT._serialized_end = 3176 + _SORT_SORTFIELD._serialized_start = 2794 + _SORT_SORTFIELD._serialized_end = 2982 + _SORT_SORTDIRECTION._serialized_start = 2984 + _SORT_SORTDIRECTION._serialized_end = 3092 + _SORT_SORTNULLS._serialized_start = 3094 + _SORT_SORTNULLS._serialized_end = 3176 + _LOCALRELATION._serialized_start = 3178 + _LOCALRELATION._serialized_end = 3271 + _SAMPLE._serialized_start = 3274 + _SAMPLE._serialized_end = 3458 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 3354fc86f45..f0a8b6412b5 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -280,29 +280,79 @@ class Read(google.protobuf.message.Message): field_name: typing_extensions.Literal["unparsed_identifier", b"unparsed_identifier"], ) -> None: ... 
+ class DataSource(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class OptionsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + + FORMAT_FIELD_NUMBER: builtins.int + SCHEMA_FIELD_NUMBER: builtins.int + OPTIONS_FIELD_NUMBER: builtins.int + format: builtins.str + """Required. Supported formats include: parquet, orc, text, json, parquet, csv, avro.""" + schema: builtins.str + """Optional. If not set, Spark will infer the schema.""" + @property + def options( + self, + ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """The key is case insensitive.""" + def __init__( + self, + *, + format: builtins.str = ..., + schema: builtins.str = ..., + options: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "format", b"format", "options", b"options", "schema", b"schema" + ], + ) -> None: ... + NAMED_TABLE_FIELD_NUMBER: builtins.int + DATA_SOURCE_FIELD_NUMBER: builtins.int @property def named_table(self) -> global___Read.NamedTable: ... + @property + def data_source(self) -> global___Read.DataSource: ... def __init__( self, *, named_table: global___Read.NamedTable | None = ..., + data_source: global___Read.DataSource | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ - "named_table", b"named_table", "read_type", b"read_type" + "data_source", b"data_source", "named_table", b"named_table", "read_type", b"read_type" ], ) -> builtins.bool: ... 
def ClearField( self, field_name: typing_extensions.Literal[ - "named_table", b"named_table", "read_type", b"read_type" + "data_source", b"data_source", "named_table", b"named_table", "read_type", b"read_type" ], ) -> None: ... def WhichOneof( self, oneof_group: typing_extensions.Literal["read_type", b"read_type"] - ) -> typing_extensions.Literal["named_table"] | None: ... + ) -> typing_extensions.Literal["named_table", "data_source"] | None: ... global___Read = Read diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 285e78e59ae..66e48eeab76 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -15,8 +15,16 @@ # limitations under the License. # + +from typing import Dict, Optional + +from pyspark.sql.connect.column import PrimitiveType from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.connect.plan import Read +from pyspark.sql.connect.plan import Read, DataSource +from pyspark.sql.utils import to_str + + +OptionalPrimitiveType = Optional[PrimitiveType] from typing import TYPE_CHECKING @@ -29,8 +37,114 @@ class DataFrameReader: TODO(SPARK-40539) Achieve parity with PySpark. """ - def __init__(self, client: "RemoteSparkSession") -> None: + def __init__(self, client: "RemoteSparkSession"): self._client = client + self._format = "" + self._schema = "" + self._options: Dict[str, str] = {} + + def format(self, source: str) -> "DataFrameReader": + """ + Specifies the input data source format. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + source : str + string, name of the data source, e.g. 'json', 'parquet'. + + """ + self._format = source + return self + + # TODO(SPARK-40539): support StructType in python client and support schema as StructType. + def schema(self, schema: str) -> "DataFrameReader": + """ + Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. 
+ By specifying the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + schema : str + a DDL-formatted string + (For example ``col0 INT, col1 DOUBLE``). + + """ + self._schema = schema + return self + + def option(self, key: str, value: "OptionalPrimitiveType") -> "DataFrameReader": + """ + Adds an input option for the underlying data source. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + key : str + The key for the option to set. key string is case-insensitive. + value + The value for the option to set. + + """ + self._options[key] = str(value) + return self + + def options(self, **options: "OptionalPrimitiveType") -> "DataFrameReader": + """ + Adds input options for the underlying data source. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + **options : dict + The dictionary of string keys and primitive-type values. + """ + for k in options: + self.option(k, to_str(options[k])) + return self + + def load( + self, + path: Optional[str] = None, + format: Optional[str] = None, + schema: Optional[str] = None, + **options: "OptionalPrimitiveType", + ) -> "DataFrame": + """ + Loads data from a data source and returns it as a :class:`DataFrame`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + path : str, optional + optional string path for file-system backed data sources. + format : str, optional + optional string for format of the data source. + schema : str, optional + optional DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). 
+ **options : dict + all other string options + """ + if format is not None: + self.format(format) + if schema is not None: + self.schema(schema) + self.options(**options) + if path is not None: + self.option("path", path) + + plan = DataSource(format=self._format, schema=self._schema, options=self._options) + df = DataFrame.withPlan(plan, self._client) + return df def table(self, tableName: str) -> "DataFrame": df = DataFrame.withPlan(Read(tableName), self._client) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 1a59e7d596e..de300946932 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -16,6 +16,7 @@ # from typing import Any import unittest +import shutil import tempfile import pandas @@ -24,6 +25,7 @@ from pyspark.sql import SparkSession, Row from pyspark.sql.connect.client import RemoteSparkSession from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit +from pyspark.sql.dataframe import DataFrame from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import ReusedPySparkTestCase @@ -35,6 +37,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase): connect: RemoteSparkSession tbl_name: str + df_text: "DataFrame" @classmethod def setUpClass(cls: Any): @@ -44,7 +47,9 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase): # Create the new Spark Session cls.spark = SparkSession(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.testDataStr = [Row(key=str(i)) for i in range(100)] cls.df = cls.sc.parallelize(cls.testData).toDF() + cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF() cls.tbl_name = "test_connect_basic_table_1" @@ -101,6 +106,21 @@ class SparkConnectTests(SparkConnectSQLTestCase): res = pandas.DataFrame(data={"id": [0, 30, 60, 90]}) 
self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}") + def test_simple_datasource_read(self) -> None: + writeDf = self.df_text + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + writeDf.write.text(tmpPath) + + readDf = self.connect.read.format("text").schema("id STRING").load(path=tmpPath) + expectResult = writeDf.collect() + pandasResult = readDf.toPandas() + if pandasResult is None: + self.assertTrue(False, "Empty pandas dataframe") + else: + actualResult = pandasResult.values.tolist() + self.assertEqual(len(expectResult), len(actualResult)) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401 diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index c547000bdcf..96bbb8aa834 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -18,6 +18,7 @@ import unittest from pyspark.testing.connectutils import PlanOnlyTestFixture import pyspark.sql.connect.proto as proto +from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.function_builder import UserDefinedFunction, udf from pyspark.sql.types import StringType @@ -48,6 +49,18 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): plan = df.alias("table_alias")._plan.to_proto(self.connect) self.assertEqual(plan.root.common.alias, "table_alias") + def test_datasource_read(self): + reader = DataFrameReader(self.connect) + df = reader.load(path="test_path", format="text", schema="id INT", op1="opv", op2="opv2") + plan = df._plan.to_proto(self.connect) + data_source = plan.root.read.data_source + self.assertEqual(data_source.format, "text") + self.assertEqual(data_source.schema, "id INT") + self.assertEqual(len(data_source.options), 3) + self.assertEqual(data_source.options.get("path"), "test_path") + self.assertEqual(data_source.options.get("op1"), "opv") + 
self.assertEqual(data_source.options.get("op2"), "opv2") + def test_simple_udf(self): u = udf(lambda x: "Martin", StringType()) self.assertIsNotNone(u) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org