This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 01c7a46f24f [SPARK-40539][CONNECT] Initial DataFrame Read API parity for Spark Connect
01c7a46f24f is described below

commit 01c7a46f24fb4bb4287a184a3d69e0e5c904bc50
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Thu Oct 20 09:26:06 2022 +0800

    [SPARK-40539][CONNECT] Initial DataFrame Read API parity for Spark Connect
    
    ### What changes were proposed in this pull request?
    
    Add an initial Read API for Spark Connect that allows setting the schema, format, options, and path, and then reading files into a DataFrame.
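
    For illustration, a minimal client-side sketch of the new API. It assumes `spark` is an
    already-connected Spark Connect session (RemoteSparkSession); the path, format, and
    schema below are example values only.

        # Chained reader calls only build a plan; execution happens when the result
        # is fetched (e.g. via toPandas()).
        df = (
            spark.read.format("text")                 # data source format (required)
            .schema("id STRING")                      # optional DDL schema, skips inference
            .option("recursiveFileLookup", "true")    # option values are stringified and forwarded
            .load(path="/tmp/connect_read_example")   # the path is sent as the "path" option
        )
        print(df.toPandas())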
    
    ### Why are the changes needed?
    
    PySpark readwriter API parity for Spark Connect
    
    ### Does this PR introduce _any_ user-facing change?
    
    No

    ### How was this patch tested?
    
    UT
    
    Closes #38086 from amaliujia/SPARK-40539.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  10 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  15 ++-
 .../connect/planner/SparkConnectPlannerSuite.scala |  11 ++
 python/pyspark/sql/connect/plan.py                 |  41 +++++++
 python/pyspark/sql/connect/proto/relations_pb2.py  |  78 +++++++-------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  56 +++++++++-
 python/pyspark/sql/connect/readwriter.py           | 118 ++++++++++++++++++++-
 .../sql/tests/connect/test_connect_basic.py        |  20 ++++
 .../sql/tests/connect/test_connect_plan_only.py    |  13 +++
 9 files changed, 320 insertions(+), 42 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 353fbebd046..eadedf495d3 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -67,11 +67,21 @@ message SQL {
 message Read {
   oneof read_type {
     NamedTable named_table = 1;
+    DataSource data_source = 2;
   }
 
   message NamedTable {
     string unparsed_identifier = 1;
   }
+
+  message DataSource {
+    // Required. Supported formats include: parquet, orc, text, json, csv, avro.
+    string format = 1;
+    // Optional. If not set, Spark will infer the schema.
+    string schema = 2;
+    // The key is case insensitive.
+    map<string, string> options = 3;
+  }
 }
 
 // Projection of a bag of expressions for a given input relation.
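
As a quick, hedged illustration of the message added above: using the Python classes generated
from this file (relations_pb2, regenerated later in this patch), a client can fill in the new
oneof branch. The format, schema, and option values here are examples only.

    import pyspark.sql.connect.proto as proto

    rel = proto.Relation()
    rel.read.data_source.format = "json"                        # required
    rel.read.data_source.schema = "id INT, name STRING"         # optional DDL schema
    rel.read.data_source.options["path"] = "/tmp/people.json"   # keys are case-insensitive on the server

    # Assigning to data_source selects that branch of the read_type oneof.
    assert rel.read.WhichOneof("read_type") == "data_source"
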
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6a6b5a15a08..450283a9b81 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeRef
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample, SubqueryAlias}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.types._
 
 final case class InvalidPlanInput(
@@ -112,7 +113,19 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
         } else {
           child
         }
-      case _ => throw InvalidPlanInput()
+      case proto.Read.ReadTypeCase.DATA_SOURCE =>
+        if (rel.getDataSource.getFormat == "") {
+          throw InvalidPlanInput("DataSource requires a format")
+        }
+        val localMap = CaseInsensitiveMap[String](rel.getDataSource.getOptionsMap.asScala.toMap)
+        val reader = session.read
+        reader.format(rel.getDataSource.getFormat)
+        localMap.foreach { case (key, value) => reader.option(key, value) }
+        if (rel.getDataSource.getSchema != null && rel.getDataSource.getSchema.nonEmpty) {
+          reader.schema(rel.getDataSource.getSchema)
+        }
+        reader.load().queryExecution.analyzed
+      case _ => throw InvalidPlanInput("Does not support " + rel.getReadTypeCase.name())
     }
     baseRelation
   }
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index fc3d219ec6b..83bf76efce1 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -255,4 +255,15 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
     assert(res.nodeName == "Aggregate")
   }
 
+  test("Invalid DataSource") {
+    val dataSource = proto.Read.DataSource.newBuilder()
+
+    val e = intercept[InvalidPlanInput](
+      transform(
+        proto.Relation
+          .newBuilder()
+          .setRead(proto.Read.newBuilder().setDataSource(dataSource))
+          .build()))
+    assert(e.getMessage.contains("DataSource requires a format"))
+  }
 }
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 5fcd468924d..c564b71cdba 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -23,6 +23,7 @@ from typing import (
     Union,
     cast,
     TYPE_CHECKING,
+    Mapping,
 )
 
 import pyspark.sql.connect.proto as proto
@@ -111,6 +112,46 @@ class LogicalPlan(object):
         return self._child._repr_html_() if self._child is not None else ""
 
 
+class DataSource(LogicalPlan):
+    """A datasource with a format and optional a schema from which Spark reads 
data"""
+
+    def __init__(
+        self,
+        format: str = "",
+        schema: Optional[str] = None,
+        options: Optional[Mapping[str, str]] = None,
+    ) -> None:
+        super().__init__(None)
+        self.format = format
+        self.schema = schema
+        self.options = options
+
+    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+        plan = proto.Relation()
+        if self.format is not None:
+            plan.read.data_source.format = self.format
+        if self.schema is not None:
+            plan.read.data_source.schema = self.schema
+        if self.options is not None:
+            for k in self.options.keys():
+                v = self.options.get(k)
+                if v is not None:
+                    plan.read.data_source.options[k] = v
+        return plan
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+            <li>
+                <b>DataSource</b><br />
+                format: {self.format}
+                schema: {self.schema}
+                options: {self.options}
+            </li>
+        </ul>
+        """
+
+
 class Read(LogicalPlan):
     def __init__(self, table_name: str) -> None:
         super().__init__(None)
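
A small, hedged sketch of how this plan node serializes, mirroring the plan-only test added at
the end of this patch; the format, schema, and options are arbitrary example values.

    from pyspark.sql.connect.plan import DataSource

    node = DataSource(format="csv", schema="a INT, b STRING", options={"header": "true"})
    rel = node.plan(None)  # the session argument is not needed to serialize this node

    assert rel.read.data_source.format == "csv"
    assert rel.read.data_source.schema == "a INT, b STRING"
    assert rel.read.data_source.options["header"] == "true"
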
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index c3b7b7ec2ea..b244cdf8dcb 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcf\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcf\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -41,6 +41,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
+    _READ_DATASOURCE_OPTIONSENTRY._options = None
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
     _RELATION._serialized_end = 801
     _UNKNOWN._serialized_start = 803
@@ -50,39 +52,43 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _SQL._serialized_start = 887
     _SQL._serialized_end = 914
     _READ._serialized_start = 917
-    _READ._serialized_end = 1066
-    _READ_NAMEDTABLE._serialized_start = 992
-    _READ_NAMEDTABLE._serialized_end = 1053
-    _PROJECT._serialized_start = 1068
-    _PROJECT._serialized_end = 1185
-    _FILTER._serialized_start = 1187
-    _FILTER._serialized_end = 1299
-    _JOIN._serialized_start = 1302
-    _JOIN._serialized_end = 1715
-    _JOIN_JOINTYPE._serialized_start = 1528
-    _JOIN_JOINTYPE._serialized_end = 1715
-    _UNION._serialized_start = 1718
-    _UNION._serialized_end = 1923
-    _UNION_UNIONTYPE._serialized_start = 1839
-    _UNION_UNIONTYPE._serialized_end = 1923
-    _LIMIT._serialized_start = 1925
-    _LIMIT._serialized_end = 2001
-    _OFFSET._serialized_start = 2003
-    _OFFSET._serialized_end = 2082
-    _AGGREGATE._serialized_start = 2085
-    _AGGREGATE._serialized_end = 2410
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2314
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2410
-    _SORT._serialized_start = 2413
-    _SORT._serialized_end = 2915
-    _SORT_SORTFIELD._serialized_start = 2533
-    _SORT_SORTFIELD._serialized_end = 2721
-    _SORT_SORTDIRECTION._serialized_start = 2723
-    _SORT_SORTDIRECTION._serialized_end = 2831
-    _SORT_SORTNULLS._serialized_start = 2833
-    _SORT_SORTNULLS._serialized_end = 2915
-    _LOCALRELATION._serialized_start = 2917
-    _LOCALRELATION._serialized_end = 3010
-    _SAMPLE._serialized_start = 3013
-    _SAMPLE._serialized_end = 3197
+    _READ._serialized_end = 1327
+    _READ_NAMEDTABLE._serialized_start = 1059
+    _READ_NAMEDTABLE._serialized_end = 1120
+    _READ_DATASOURCE._serialized_start = 1123
+    _READ_DATASOURCE._serialized_end = 1314
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1256
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1314
+    _PROJECT._serialized_start = 1329
+    _PROJECT._serialized_end = 1446
+    _FILTER._serialized_start = 1448
+    _FILTER._serialized_end = 1560
+    _JOIN._serialized_start = 1563
+    _JOIN._serialized_end = 1976
+    _JOIN_JOINTYPE._serialized_start = 1789
+    _JOIN_JOINTYPE._serialized_end = 1976
+    _UNION._serialized_start = 1979
+    _UNION._serialized_end = 2184
+    _UNION_UNIONTYPE._serialized_start = 2100
+    _UNION_UNIONTYPE._serialized_end = 2184
+    _LIMIT._serialized_start = 2186
+    _LIMIT._serialized_end = 2262
+    _OFFSET._serialized_start = 2264
+    _OFFSET._serialized_end = 2343
+    _AGGREGATE._serialized_start = 2346
+    _AGGREGATE._serialized_end = 2671
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2575
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2671
+    _SORT._serialized_start = 2674
+    _SORT._serialized_end = 3176
+    _SORT_SORTFIELD._serialized_start = 2794
+    _SORT_SORTFIELD._serialized_end = 2982
+    _SORT_SORTDIRECTION._serialized_start = 2984
+    _SORT_SORTDIRECTION._serialized_end = 3092
+    _SORT_SORTNULLS._serialized_start = 3094
+    _SORT_SORTNULLS._serialized_end = 3176
+    _LOCALRELATION._serialized_start = 3178
+    _LOCALRELATION._serialized_end = 3271
+    _SAMPLE._serialized_start = 3274
+    _SAMPLE._serialized_end = 3458
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 3354fc86f45..f0a8b6412b5 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -280,29 +280,79 @@ class Read(google.protobuf.message.Message):
            field_name: typing_extensions.Literal["unparsed_identifier", b"unparsed_identifier"],
         ) -> None: ...
 
+    class DataSource(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        class OptionsEntry(google.protobuf.message.Message):
+            DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+            KEY_FIELD_NUMBER: builtins.int
+            VALUE_FIELD_NUMBER: builtins.int
+            key: builtins.str
+            value: builtins.str
+            def __init__(
+                self,
+                *,
+                key: builtins.str = ...,
+                value: builtins.str = ...,
+            ) -> None: ...
+            def ClearField(
+                self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+            ) -> None: ...
+
+        FORMAT_FIELD_NUMBER: builtins.int
+        SCHEMA_FIELD_NUMBER: builtins.int
+        OPTIONS_FIELD_NUMBER: builtins.int
+        format: builtins.str
+        """Required. Supported formats include: parquet, orc, text, json, 
parquet, csv, avro."""
+        schema: builtins.str
+        """Optional. If not set, Spark will infer the schema."""
+        @property
+        def options(
+            self,
+        ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
+            """The key is case insensitive."""
+        def __init__(
+            self,
+            *,
+            format: builtins.str = ...,
+            schema: builtins.str = ...,
+            options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+        ) -> None: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "format", b"format", "options", b"options", "schema", b"schema"
+            ],
+        ) -> None: ...
+
     NAMED_TABLE_FIELD_NUMBER: builtins.int
+    DATA_SOURCE_FIELD_NUMBER: builtins.int
     @property
     def named_table(self) -> global___Read.NamedTable: ...
+    @property
+    def data_source(self) -> global___Read.DataSource: ...
     def __init__(
         self,
         *,
         named_table: global___Read.NamedTable | None = ...,
+        data_source: global___Read.DataSource | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "named_table", b"named_table", "read_type", b"read_type"
+            "data_source", b"data_source", "named_table", b"named_table", 
"read_type", b"read_type"
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "named_table", b"named_table", "read_type", b"read_type"
+            "data_source", b"data_source", "named_table", b"named_table", 
"read_type", b"read_type"
         ],
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["read_type", b"read_type"]
-    ) -> typing_extensions.Literal["named_table"] | None: ...
+    ) -> typing_extensions.Literal["named_table", "data_source"] | None: ...
 
 global___Read = Read
 
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 285e78e59ae..66e48eeab76 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -15,8 +15,16 @@
 # limitations under the License.
 #
 
+
+from typing import Dict, Optional
+
+from pyspark.sql.connect.column import PrimitiveType
 from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import Read
+from pyspark.sql.connect.plan import Read, DataSource
+from pyspark.sql.utils import to_str
+
+
+OptionalPrimitiveType = Optional[PrimitiveType]
 
 from typing import TYPE_CHECKING
 
@@ -29,8 +37,114 @@ class DataFrameReader:
     TODO(SPARK-40539) Achieve parity with PySpark.
     """
 
-    def __init__(self, client: "RemoteSparkSession") -> None:
+    def __init__(self, client: "RemoteSparkSession"):
         self._client = client
+        self._format = ""
+        self._schema = ""
+        self._options: Dict[str, str] = {}
+
+    def format(self, source: str) -> "DataFrameReader":
+        """
+        Specifies the input data source format.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        source : str
+        string, name of the data source, e.g. 'json', 'parquet'.
+
+        """
+        self._format = source
+        return self
+
+    # TODO(SPARK-40539): support StructType in python client and support schema as StructType.
+    def schema(self, schema: str) -> "DataFrameReader":
+        """
+        Specifies the input schema.
+
+        Some data sources (e.g. JSON) can infer the input schema automatically from data.
+        By specifying the schema here, the underlying data source can skip the schema
+        inference step, and thus speed up data loading.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        schema : str
+        a DDL-formatted string
+        (For example ``col0 INT, col1 DOUBLE``).
+
+        """
+        self._schema = schema
+        return self
+
+    def option(self, key: str, value: "OptionalPrimitiveType") -> "DataFrameReader":
+        """
+        Adds an input option for the underlying data source.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        key : str
+            The key for the option to set. key string is case-insensitive.
+        value
+            The value for the option to set.
+
+        """
+        self._options[key] = str(value)
+        return self
+
+    def options(self, **options: "OptionalPrimitiveType") -> "DataFrameReader":
+        """
+        Adds input options for the underlying data source.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        **options : dict
+            The dictionary of string keys and primitive-type values.
+        """
+        for k in options:
+            self.option(k, to_str(options[k]))
+        return self
+
+    def load(
+        self,
+        path: Optional[str] = None,
+        format: Optional[str] = None,
+        schema: Optional[str] = None,
+        **options: "OptionalPrimitiveType",
+    ) -> "DataFrame":
+        """
+        Loads data from a data source and returns it as a :class:`DataFrame`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        path : str or list, optional
+            optional string or a list of string for file-system backed data sources.
+        format : str, optional
+            optional string for format of the data source.
+        schema : str, optional
+            optional DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
+        **options : dict
+            all other string options
+        """
+        if format is not None:
+            self.format(format)
+        if schema is not None:
+            self.schema(schema)
+        self.options(**options)
+        if path is not None:
+            self.option("path", path)
+
+        plan = DataSource(format=self._format, schema=self._schema, options=self._options)
+        df = DataFrame.withPlan(plan, self._client)
+        return df
 
     def table(self, tableName: str) -> "DataFrame":
         df = DataFrame.withPlan(Read(tableName), self._client)
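
For reference, a hedged, plan-only sketch of how the reader calls above funnel into a single
DataSource plan. Here `client` stands in for a RemoteSparkSession; nothing is executed, the
snippet only builds and inspects the proto plan, much like test_datasource_read further down.

    from pyspark.sql.connect.readwriter import DataFrameReader

    reader = DataFrameReader(client)     # `client` is an assumed RemoteSparkSession
    df = (
        reader.format("json")
        .schema("id INT")                # a DDL string for now; StructType support is a TODO above
        .option("multiLine", True)       # a single option; the value is stringified
        .options(mode="PERMISSIVE")      # keyword options merge into the same option map
        .load(path="/tmp/people.json")   # the path is recorded as the "path" option
    )
    plan = df._plan.to_proto(client)
    print(plan.root.read.data_source.options)
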
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 1a59e7d596e..de300946932 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -16,6 +16,7 @@
 #
 from typing import Any
 import unittest
+import shutil
 import tempfile
 
 import pandas
@@ -24,6 +25,7 @@ from pyspark.sql import SparkSession, Row
 from pyspark.sql.connect.client import RemoteSparkSession
 from pyspark.sql.connect.function_builder import udf
 from pyspark.sql.connect.functions import lit
+from pyspark.sql.dataframe import DataFrame
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import ReusedPySparkTestCase
 
@@ -35,6 +37,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
 
     connect: RemoteSparkSession
     tbl_name: str
+    df_text: "DataFrame"
 
     @classmethod
     def setUpClass(cls: Any):
@@ -44,7 +47,9 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
         # Create the new Spark Session
         cls.spark = SparkSession(cls.sc)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        cls.testDataStr = [Row(key=str(i)) for i in range(100)]
         cls.df = cls.sc.parallelize(cls.testData).toDF()
+        cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()
 
         cls.tbl_name = "test_connect_basic_table_1"
 
@@ -101,6 +106,21 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         res = pandas.DataFrame(data={"id": [0, 30, 60, 90]})
         self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}")
 
+    def test_simple_datasource_read(self) -> None:
+        writeDf = self.df_text
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        writeDf.write.text(tmpPath)
+
+        readDf = self.connect.read.format("text").schema("id STRING").load(path=tmpPath)
+        expectResult = writeDf.collect()
+        pandasResult = readDf.toPandas()
+        if pandasResult is None:
+            self.assertTrue(False, "Empty pandas dataframe")
+        else:
+            actualResult = pandasResult.values.tolist()
+            self.assertEqual(len(expectResult), len(actualResult))
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index c547000bdcf..96bbb8aa834 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -18,6 +18,7 @@ import unittest
 
 from pyspark.testing.connectutils import PlanOnlyTestFixture
 import pyspark.sql.connect.proto as proto
+from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
 from pyspark.sql.types import StringType
 
@@ -48,6 +49,18 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         plan = df.alias("table_alias")._plan.to_proto(self.connect)
         self.assertEqual(plan.root.common.alias, "table_alias")
 
+    def test_datasource_read(self):
+        reader = DataFrameReader(self.connect)
+        df = reader.load(path="test_path", format="text", schema="id INT", op1="opv", op2="opv2")
+        plan = df._plan.to_proto(self.connect)
+        data_source = plan.root.read.data_source
+        self.assertEqual(data_source.format, "text")
+        self.assertEqual(data_source.schema, "id INT")
+        self.assertEqual(len(data_source.options), 3)
+        self.assertEqual(data_source.options.get("path"), "test_path")
+        self.assertEqual(data_source.options.get("op1"), "opv")
+        self.assertEqual(data_source.options.get("op2"), "opv2")
+
     def test_simple_udf(self):
         u = udf(lambda x: "Martin", StringType())
         self.assertIsNotNone(u)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
