This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 23e3c9b7c2f [SPARK-41828][CONNECT][PYTHON] Make `createDataFrame` support empty dataframe
23e3c9b7c2f is described below

commit 23e3c9b7c2f08c5350992934cf660de6d2793982
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Jan 4 17:45:46 2023 +0900

    [SPARK-41828][CONNECT][PYTHON] Make `createDataFrame` support empty dataframe
    
    ### What changes were proposed in this pull request?
    Make `createDataFrame` support creating an empty DataFrame:
    
    ```
    In [24]: spark.createDataFrame([], schema="x STRING, y INTEGER")
    Out[24]: DataFrame[x: string, y: int]
    ```
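    
    A `StructType` schema works as well, while omitting the schema for empty input raises an error (a usage sketch based on the tests added in this PR, assuming `spark` is a Spark Connect session):
    
    ```
    In [25]: from pyspark.sql.types import StructType, StructField, StringType, IntegerType
    
    In [26]: spark.createDataFrame([], schema=StructType([StructField("x", StringType()), StructField("y", IntegerType())]))
    Out[26]: DataFrame[x: string, y: int]
    
    In [27]: spark.createDataFrame([])
    ...
    ValueError: can not infer schema from empty dataset
    ```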
    
    ### Why are the changes needed?
    To keep the Spark Connect client consistent with the existing PySpark `createDataFrame` behavior.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `createDataFrame` on a Spark Connect session now accepts empty input data when a schema is provided.
    
    ### How was this patch tested?
    Added a unit test and enabled the related doctests.
    
    Closes #39379 from zhengruifeng/connect_fix_41828.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  18 ++--
 .../sql/connect/planner/SparkConnectPlanner.scala  |  68 ++++++++-----
 python/pyspark/sql/connect/dataframe.py            |   3 -
 python/pyspark/sql/connect/plan.py                 |  34 ++++---
 python/pyspark/sql/connect/proto/relations_pb2.py  | 110 ++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  41 ++++----
 python/pyspark/sql/connect/session.py              |  32 ++++--
 .../sql/tests/connect/test_connect_basic.py        |  28 ++++++
 8 files changed, 193 insertions(+), 141 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 51981714ded..c0f22dd4576 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -328,20 +328,16 @@ message Deduplicate {
 
 // A relation that does not need to be qualified by name.
 message LocalRelation {
-  // Local collection data serialized into Arrow IPC streaming format which contains
+  // (Optional) Local collection data serialized into Arrow IPC streaming format which contains
   // the schema of the data.
-  bytes data = 1;
+  optional bytes data = 1;
 
-  // (Optional) The user provided schema.
+  // (Optional) The schema of local data.
+  // It should be either a DDL-formatted type string or a JSON string.
   //
-  // The Sever side will update the column names and data types according to this schema.
-  oneof schema {
-
-    DataType datatype = 2;
-
-    // Server will use Catalyst parser to parse this string to DataType.
-    string datatype_str = 3;
-  }
+  // The server side will update the column names and data types according to this schema.
+  // If the 'data' is not provided, then this schema will be required.
+  optional string schema = 2;
 }
 
 // Relation of type [[Sample]] that samples a fraction of the dataset.
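
For context, a minimal sketch of how a client can populate the new message for the empty-dataframe case, mirroring what `LocalRelation.plan()` in python/pyspark/sql/connect/plan.py (below) now does; it assumes the generated `pyspark.sql.connect.proto` bindings are importable:

```
import pyspark.sql.connect.proto as proto

# Schema-only LocalRelation: no Arrow batch is attached, so the server must
# build an empty relation from the DDL (or JSON) schema string alone.
rel = proto.Relation()
rel.local_relation.schema = "x STRING, y INTEGER"

assert rel.local_relation.HasField("schema")
assert not rel.local_relation.HasField("data")
```
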
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 754bb7ced9e..b4c882541e0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -571,47 +571,61 @@ class SparkConnectPlanner(session: SparkSession) {
     try {
       parser.parseTableSchema(sqlText)
     } catch {
-      case _: ParseException =>
+      case e: ParseException =>
         try {
           parser.parseDataType(sqlText)
         } catch {
           case _: ParseException =>
-            parser.parseDataType(s"struct<${sqlText.trim}>")
+            try {
+              parser.parseDataType(s"struct<${sqlText.trim}>")
+            } catch {
+              case _: ParseException =>
+                throw e
+            }
         }
     }
   }
 
   private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
-    val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
-      Iterator(rel.getData.toByteArray),
-      TaskContext.get())
-    if (structType == null) {
-      throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
+    var schema: StructType = null
+    if (rel.hasSchema) {
+      val schemaType = DataType.parseTypeWithFallback(
+        rel.getSchema,
+        parseDatatypeString,
+        fallbackParser = DataType.fromJson)
+      schema = schemaType match {
+        case s: StructType => s
+        case d => StructType(Seq(StructField("value", d)))
+      }
     }
-    val attributes = structType.toAttributes
-    val proj = UnsafeProjection.create(attributes, attributes)
-    val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
 
-    if (!rel.hasDatatype && !rel.hasDatatypeStr) {
-      return relation
-    }
+    if (rel.hasData) {
+      val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
+        Iterator(rel.getData.toByteArray),
+        TaskContext.get())
+      if (structType == null) {
+        throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
+      }
+      val attributes = structType.toAttributes
+      val proj = UnsafeProjection.create(attributes, attributes)
+      val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
 
-    val schemaType = if (rel.hasDatatype) {
-      DataTypeProtoConverter.toCatalystType(rel.getDatatype)
+      if (schema == null) {
+        relation
+      } else {
+        Dataset
+          .ofRows(session, logicalPlan = relation)
+          .toDF(schema.names: _*)
+          .to(schema)
+          .logicalPlan
+      }
     } else {
-      parseDatatypeString(rel.getDatatypeStr)
-    }
-
-    val schemaStruct = schemaType match {
-      case s: StructType => s
-      case d => StructType(Seq(StructField("value", d)))
+      if (schema == null) {
+        throw InvalidPlanInput(
+          s"Schema for LocalRelation is required when the input data is not 
provided.")
+      }
+      LocalRelation(schema.toAttributes, data = Seq.empty)
     }
-
-    Dataset
-      .ofRows(session, logicalPlan = relation)
-      .toDF(schemaStruct.names: _*)
-      .to(schemaStruct)
-      .logicalPlan
   }
 
   private def transformReadRel(rel: proto.Read): LogicalPlan = {
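
A note on the schema string handled above: when the client is given a `StructType` it sends the schema as JSON (see the session.py change below), which is covered by the `DataType.fromJson` fallback, while plain DDL strings such as "x STRING, y INTEGER" go through the `parseDatatypeString` chain. A small illustration (assumes only `pyspark` is installed):

```
from pyspark.sql.types import StructType, StructField, StringType

# A StructType is shipped to the server as its JSON form, not as DDL text.
schema = StructType([StructField("x", StringType(), True)])
print(schema.json())
# e.g. {"fields":[{"metadata":{},"name":"x","nullable":true,"type":"string"}],"type":"struct"}
```
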
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 57c9e801c22..646cc5ced9a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1426,9 +1426,6 @@ def _test() -> None:
         # TODO(SPARK-41827): groupBy requires all cols be Column or str
         del pyspark.sql.connect.dataframe.DataFrame.groupBy.__doc__
 
-        # TODO(SPARK-41828): Implement creating empty DataFrame
-        del pyspark.sql.connect.dataframe.DataFrame.isEmpty.__doc__
-
         # TODO(SPARK-41829): Add Dataframe sort ordering
         del pyspark.sql.connect.dataframe.DataFrame.sort.__doc__
        del pyspark.sql.connect.dataframe.DataFrame.sortWithinPartitions.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 48a8fa598e7..1f4e4192fdf 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -270,30 +270,34 @@ class LocalRelation(LogicalPlan):
 
     def __init__(
         self,
-        table: "pa.Table",
-        schema: Optional[Union[DataType, str]] = None,
+        table: Optional["pa.Table"],
+        schema: Optional[str] = None,
     ) -> None:
         super().__init__(None)
-        assert table is not None and isinstance(table, pa.Table)
+
+        if table is None:
+            assert schema is not None
+        else:
+            assert isinstance(table, pa.Table)
+
+        assert schema is None or isinstance(schema, str)
+
         self._table = table
 
-        if schema is not None:
-            assert isinstance(schema, (DataType, str))
         self._schema = schema
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        sink = pa.BufferOutputStream()
-        with pa.ipc.new_stream(sink, self._table.schema) as writer:
-            for b in self._table.to_batches():
-                writer.write_batch(b)
-
         plan = proto.Relation()
-        plan.local_relation.data = sink.getvalue().to_pybytes()
+
+        if self._table is not None:
+            sink = pa.BufferOutputStream()
+            with pa.ipc.new_stream(sink, self._table.schema) as writer:
+                for b in self._table.to_batches():
+                    writer.write_batch(b)
+            plan.local_relation.data = sink.getvalue().to_pybytes()
+
         if self._schema is not None:
-            if isinstance(self._schema, DataType):
-                plan.local_relation.datatype.CopyFrom(pyspark_types_to_proto_types(self._schema))
-            elif isinstance(self._schema, str):
-                plan.local_relation.datatype_str = self._schema
+            plan.local_relation.schema = self._schema
         return plan
 
     def print(self, indent: int = 0) -> str:
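
The Arrow IPC serialization done in `plan()` above can be reproduced standalone; a minimal sketch using only `pyarrow` (already a Spark Connect dependency), with an illustrative table:

```
import pyarrow as pa

# Serialize a pyarrow Table into the Arrow IPC streaming format that the
# server-side ArrowConverters reads back out of LocalRelation.data.
table = pa.table({"x": ["a", "b"], "y": [1, 2]})

sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    for batch in table.to_batches():
        writer.write_batch(batch)

ipc_bytes = sink.getvalue().to_pybytes()  # what gets assigned to local_relation.data
```
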
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index cf0f2eb3513..9e230c3d239 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -656,58 +656,58 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _DROP._serialized_end = 5290
     _DEDUPLICATE._serialized_start = 5293
     _DEDUPLICATE._serialized_end = 5464
-    _LOCALRELATION._serialized_start = 5467
-    _LOCALRELATION._serialized_end = 5604
-    _SAMPLE._serialized_start = 5607
-    _SAMPLE._serialized_end = 5880
-    _RANGE._serialized_start = 5883
-    _RANGE._serialized_end = 6028
-    _SUBQUERYALIAS._serialized_start = 6030
-    _SUBQUERYALIAS._serialized_end = 6144
-    _REPARTITION._serialized_start = 6147
-    _REPARTITION._serialized_end = 6289
-    _SHOWSTRING._serialized_start = 6292
-    _SHOWSTRING._serialized_end = 6434
-    _STATSUMMARY._serialized_start = 6436
-    _STATSUMMARY._serialized_end = 6528
-    _STATDESCRIBE._serialized_start = 6530
-    _STATDESCRIBE._serialized_end = 6611
-    _STATCROSSTAB._serialized_start = 6613
-    _STATCROSSTAB._serialized_end = 6714
-    _STATCOV._serialized_start = 6716
-    _STATCOV._serialized_end = 6812
-    _STATCORR._serialized_start = 6815
-    _STATCORR._serialized_end = 6952
-    _STATAPPROXQUANTILE._serialized_start = 6955
-    _STATAPPROXQUANTILE._serialized_end = 7119
-    _STATFREQITEMS._serialized_start = 7121
-    _STATFREQITEMS._serialized_end = 7246
-    _STATSAMPLEBY._serialized_start = 7249
-    _STATSAMPLEBY._serialized_end = 7558
-    _STATSAMPLEBY_FRACTION._serialized_start = 7450
-    _STATSAMPLEBY_FRACTION._serialized_end = 7549
-    _NAFILL._serialized_start = 7561
-    _NAFILL._serialized_end = 7695
-    _NADROP._serialized_start = 7698
-    _NADROP._serialized_end = 7832
-    _NAREPLACE._serialized_start = 7835
-    _NAREPLACE._serialized_end = 8131
-    _NAREPLACE_REPLACEMENT._serialized_start = 7990
-    _NAREPLACE_REPLACEMENT._serialized_end = 8131
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8133
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8247
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8250
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8509
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8442
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8509
-    _WITHCOLUMNS._serialized_start = 8512
-    _WITHCOLUMNS._serialized_end = 8643
-    _HINT._serialized_start = 8646
-    _HINT._serialized_end = 8786
-    _UNPIVOT._serialized_start = 8789
-    _UNPIVOT._serialized_end = 9035
-    _TOSCHEMA._serialized_start = 9037
-    _TOSCHEMA._serialized_end = 9143
-    _REPARTITIONBYEXPRESSION._serialized_start = 9146
-    _REPARTITIONBYEXPRESSION._serialized_end = 9349
+    _LOCALRELATION._serialized_start = 5466
+    _LOCALRELATION._serialized_end = 5555
+    _SAMPLE._serialized_start = 5558
+    _SAMPLE._serialized_end = 5831
+    _RANGE._serialized_start = 5834
+    _RANGE._serialized_end = 5979
+    _SUBQUERYALIAS._serialized_start = 5981
+    _SUBQUERYALIAS._serialized_end = 6095
+    _REPARTITION._serialized_start = 6098
+    _REPARTITION._serialized_end = 6240
+    _SHOWSTRING._serialized_start = 6243
+    _SHOWSTRING._serialized_end = 6385
+    _STATSUMMARY._serialized_start = 6387
+    _STATSUMMARY._serialized_end = 6479
+    _STATDESCRIBE._serialized_start = 6481
+    _STATDESCRIBE._serialized_end = 6562
+    _STATCROSSTAB._serialized_start = 6564
+    _STATCROSSTAB._serialized_end = 6665
+    _STATCOV._serialized_start = 6667
+    _STATCOV._serialized_end = 6763
+    _STATCORR._serialized_start = 6766
+    _STATCORR._serialized_end = 6903
+    _STATAPPROXQUANTILE._serialized_start = 6906
+    _STATAPPROXQUANTILE._serialized_end = 7070
+    _STATFREQITEMS._serialized_start = 7072
+    _STATFREQITEMS._serialized_end = 7197
+    _STATSAMPLEBY._serialized_start = 7200
+    _STATSAMPLEBY._serialized_end = 7509
+    _STATSAMPLEBY_FRACTION._serialized_start = 7401
+    _STATSAMPLEBY_FRACTION._serialized_end = 7500
+    _NAFILL._serialized_start = 7512
+    _NAFILL._serialized_end = 7646
+    _NADROP._serialized_start = 7649
+    _NADROP._serialized_end = 7783
+    _NAREPLACE._serialized_start = 7786
+    _NAREPLACE._serialized_end = 8082
+    _NAREPLACE_REPLACEMENT._serialized_start = 7941
+    _NAREPLACE_REPLACEMENT._serialized_end = 8082
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8084
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8198
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8201
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8460
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8393
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8460
+    _WITHCOLUMNS._serialized_start = 8463
+    _WITHCOLUMNS._serialized_end = 8594
+    _HINT._serialized_start = 8597
+    _HINT._serialized_end = 8737
+    _UNPIVOT._serialized_start = 8740
+    _UNPIVOT._serialized_end = 8986
+    _TOSCHEMA._serialized_start = 8988
+    _TOSCHEMA._serialized_end = 9094
+    _REPARTITIONBYEXPRESSION._serialized_start = 9097
+    _REPARTITIONBYEXPRESSION._serialized_end = 9300
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 7e63d363277..811f005d24b 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1268,45 +1268,44 @@ class LocalRelation(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
     DATA_FIELD_NUMBER: builtins.int
-    DATATYPE_FIELD_NUMBER: builtins.int
-    DATATYPE_STR_FIELD_NUMBER: builtins.int
+    SCHEMA_FIELD_NUMBER: builtins.int
     data: builtins.bytes
-    """Local collection data serialized into Arrow IPC streaming format which 
contains
+    """(Optional) Local collection data serialized into Arrow IPC streaming 
format which contains
     the schema of the data.
     """
-    @property
-    def datatype(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
-    datatype_str: builtins.str
-    """Server will use Catalyst parser to parse this string to DataType."""
+    schema: builtins.str
+    """(Optional) The schema of local data.
+    It should be either a DDL-formatted type string or a JSON string.
+
+    The server side will update the column names and data types according to this schema.
+    If the 'data' is not provided, then this schema will be required.
+    """
     def __init__(
         self,
         *,
-        data: builtins.bytes = ...,
-        datatype: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
-        datatype_str: builtins.str = ...,
+        data: builtins.bytes | None = ...,
+        schema: builtins.str | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "datatype", b"datatype", "datatype_str", b"datatype_str", 
"schema", b"schema"
+            "_data", b"_data", "_schema", b"_schema", "data", b"data", 
"schema", b"schema"
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "data",
-            b"data",
-            "datatype",
-            b"datatype",
-            "datatype_str",
-            b"datatype_str",
-            "schema",
-            b"schema",
+            "_data", b"_data", "_schema", b"_schema", "data", b"data", 
"schema", b"schema"
         ],
     ) -> None: ...
+    @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_data", b"_data"]
+    ) -> typing_extensions.Literal["data"] | None: ...
+    @typing.overload
     def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["schema", b"schema"]
-    ) -> typing_extensions.Literal["datatype", "datatype_str"] | None: ...
+        self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
+    ) -> typing_extensions.Literal["schema"] | None: ...
 
 global___LocalRelation = LocalRelation
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index a5d778e9c0e..09ad58fa3e0 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -31,6 +31,7 @@ from pyspark.sql.types import (
     Row,
     DataType,
     StructType,
+    AtomicType,
 )
 from pyspark.sql.utils import to_str
 
@@ -177,20 +178,18 @@ class SparkSession:
     def createDataFrame(
         self,
         data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
-        schema: Optional[Union[StructType, str, List[str], Tuple[str, ...]]] = None,
+        schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None,
     ) -> "DataFrame":
         assert data is not None
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
-        if isinstance(data, Sized) and len(data) == 0:
-            raise ValueError("Input data cannot be empty")
 
         table: Optional[pa.Table] = None
-        _schema: Optional[StructType] = None
+        _schema: Optional[Union[AtomicType, StructType]] = None
         _schema_str: Optional[str] = None
         _cols: Optional[List[str]] = None
 
-        if isinstance(schema, StructType):
+        if isinstance(schema, (AtomicType, StructType)):
             _schema = schema
 
         elif isinstance(schema, str):
@@ -200,6 +199,14 @@ class SparkSession:
            # Must re-encode any unicode strings to be consistent with StructField names
            _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
 
+        if isinstance(data, Sized) and len(data) == 0:
+            if _schema is not None:
+                return DataFrame.withPlan(LocalRelation(table=None, schema=_schema.json()), self)
+            elif _schema_str is not None:
+                return DataFrame.withPlan(LocalRelation(table=None, schema=_schema_str), self)
+            else:
+                raise ValueError("can not infer schema from empty dataset")
+
         if isinstance(data, pd.DataFrame):
             table = pa.Table.from_pandas(data)
 
@@ -253,8 +260,10 @@ class SparkSession:
                         _cols = ["_%s" % i for i in range(1, len(_data[0]) + 
1)]
                     else:
                         _cols = ["_1"]
-                else:
+                elif isinstance(_schema, StructType):
                     _cols = _schema.names
+                else:
+                    _cols = ["value"]
 
             if isinstance(_data[0], Row):
                table = pa.Table.from_pylist([row.asDict(recursive=True) for row in _data])
@@ -268,19 +277,24 @@ class SparkSession:
 
         # Validate number of columns
         num_cols = table.shape[1]
-        if _schema is not None and len(_schema.fields) != num_cols:
+        if (
+            _schema is not None
+            and isinstance(_schema, StructType)
+            and len(_schema.fields) != num_cols
+        ):
             raise ValueError(
                 f"Length mismatch: Expected axis has {num_cols} elements, "
                 f"new values have {len(_schema.fields)} elements"
             )
-        elif _cols is not None and len(_cols) != num_cols:
+
+        if _cols is not None and len(_cols) != num_cols:
             raise ValueError(
                 f"Length mismatch: Expected axis has {num_cols} elements, "
                 f"new values have {len(_cols)} elements"
             )
 
         if _schema is not None:
-            return DataFrame.withPlan(LocalRelation(table, schema=_schema), self)
+            return DataFrame.withPlan(LocalRelation(table, schema=_schema.json()), self)
         elif _schema_str is not None:
            return DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self)
         elif _cols is not None and len(_cols) > 0:
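
With the `AtomicType` support added above, a non-struct schema yields a single column named `value`, matching the server-side wrapping into `struct<value: ...>`. A usage sketch (assumes `spark` is a Spark Connect session):

```
from pyspark.sql.types import StringType

# Empty data plus an atomic schema produces an empty single-column DataFrame.
df = spark.createDataFrame([], schema=StringType())
print(df)            # DataFrame[value: string]
print(df.isEmpty())  # True
```
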
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index e82dc7f7f76..fe6c2c65e25 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -525,6 +525,34 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             sdf.select(SF.pmod("a", "b")).toPandas(),
         )
 
+    def test_create_empty_df(self):
+        for schema in [
+            "STRING",
+            "x STRING",
+            "x STRING, y INTEGER",
+            StringType(),
+            StructType(
+                [
+                    StructField("x", StringType(), True),
+                    StructField("y", IntegerType(), True),
+                ]
+            ),
+        ]:
+            cdf = self.connect.createDataFrame(data=[], schema=schema)
+            sdf = self.spark.createDataFrame(data=[], schema=schema)
+
+            self.assert_eq(cdf.toPandas(), sdf.toPandas())
+
+        # check error
+        with self.assertRaisesRegex(
+            ValueError,
+            "can not infer schema from empty dataset",
+        ):
+            self.connect.createDataFrame(data=[])
+
     def test_simple_explain_string(self):
         df = self.connect.read.table(self.tbl_name).limit(10)
         result = df._explain_string()

