This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 23e3c9b7c2f [SPARK-41828][CONNECT][PYTHON] Make `createDataFrame` support empty dataframe
23e3c9b7c2f is described below
commit 23e3c9b7c2f08c5350992934cf660de6d2793982
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jan 4 17:45:46 2023 +0900
[SPARK-41828][CONNECT][PYTHON] Make `createDataFrame` support empty dataframe
### What changes were proposed in this pull request?
Make `createDataFrame` support empty dataframe:
```
In [24]: spark.createDataFrame([], schema="x STRING, y INTEGER")
Out[24]: DataFrame[x: string, y: int]
```
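For illustration only (not part of the change itself), the resulting DataFrame behaves like any other empty DataFrame, e.g.:
```
In [25]: df = spark.createDataFrame([], schema="x STRING, y INTEGER")

In [26]: df.isEmpty(), df.count()
Out[26]: (True, 0)
```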
### Why are the changes needed?
To be consistent with PySpark.
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
Added a unit test and enabled doctests.
Closes #39379 from zhengruifeng/connect_fix_41828.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 18 ++--
.../sql/connect/planner/SparkConnectPlanner.scala | 68 ++++++++-----
python/pyspark/sql/connect/dataframe.py | 3 -
python/pyspark/sql/connect/plan.py | 34 ++++---
python/pyspark/sql/connect/proto/relations_pb2.py | 110 ++++++++++-----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 41 ++++----
python/pyspark/sql/connect/session.py | 32 ++++--
.../sql/tests/connect/test_connect_basic.py | 28 ++++++
8 files changed, 193 insertions(+), 141 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 51981714ded..c0f22dd4576 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -328,20 +328,16 @@ message Deduplicate {
// A relation that does not need to be qualified by name.
message LocalRelation {
- // Local collection data serialized into Arrow IPC streaming format which contains
+ // (Optional) Local collection data serialized into Arrow IPC streaming format which contains
// the schema of the data.
- bytes data = 1;
+ optional bytes data = 1;
- // (Optional) The user provided schema.
+ // (Optional) The schema of local data.
+ // It should be either a DDL-formatted type string or a JSON string.
//
- // The Sever side will update the column names and data types according to this schema.
- oneof schema {
-
- DataType datatype = 2;
-
- // Server will use Catalyst parser to parse this string to DataType.
- string datatype_str = 3;
- }
+ // The server side will update the column names and data types according to this schema.
+ // If the 'data' is not provided, then this schema will be required.
+ optional string schema = 2;
}
// Relation of type [[Sample]] that samples a fraction of the dataset.
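As a side note on the new message shape (a minimal sketch, not part of the patch; it assumes the generated `relations_pb2` module is importable as it is in `plan.py` below): a client sending an empty local relation populates only the `schema` string and leaves the optional `data` field unset.
```
from pyspark.sql.connect.proto import relations_pb2

rel = relations_pb2.Relation()
# Only the DDL schema string is set for an empty DataFrame; the optional
# 'data' bytes stay unset, so field presence can be checked on the server.
rel.local_relation.schema = "x STRING, y INTEGER"

assert rel.local_relation.HasField("schema")
assert not rel.local_relation.HasField("data")
```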
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 754bb7ced9e..b4c882541e0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -571,47 +571,61 @@ class SparkConnectPlanner(session: SparkSession) {
try {
parser.parseTableSchema(sqlText)
} catch {
- case _: ParseException =>
+ case e: ParseException =>
try {
parser.parseDataType(sqlText)
} catch {
case _: ParseException =>
- parser.parseDataType(s"struct<${sqlText.trim}>")
+ try {
+ parser.parseDataType(s"struct<${sqlText.trim}>")
+ } catch {
+ case _: ParseException =>
+ throw e
+ }
}
}
}
private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
- val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
- Iterator(rel.getData.toByteArray),
- TaskContext.get())
- if (structType == null) {
- throw InvalidPlanInput(s"Input data for LocalRelation does not produce a
schema.")
+ var schema: StructType = null
+ if (rel.hasSchema) {
+ val schemaType = DataType.parseTypeWithFallback(
+ rel.getSchema,
+ parseDatatypeString,
+ fallbackParser = DataType.fromJson)
+ schema = schemaType match {
+ case s: StructType => s
+ case d => StructType(Seq(StructField("value", d)))
+ }
}
- val attributes = structType.toAttributes
- val proj = UnsafeProjection.create(attributes, attributes)
- val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
- if (!rel.hasDatatype && !rel.hasDatatypeStr) {
- return relation
- }
+ if (rel.hasData) {
+ val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
+ Iterator(rel.getData.toByteArray),
+ TaskContext.get())
+ if (structType == null) {
+ throw InvalidPlanInput(s"Input data for LocalRelation does not produce
a schema.")
+ }
+ val attributes = structType.toAttributes
+ val proj = UnsafeProjection.create(attributes, attributes)
+ val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
- val schemaType = if (rel.hasDatatype) {
- DataTypeProtoConverter.toCatalystType(rel.getDatatype)
+ if (schema == null) {
+ relation
+ } else {
+ Dataset
+ .ofRows(session, logicalPlan = relation)
+ .toDF(schema.names: _*)
+ .to(schema)
+ .logicalPlan
+ }
} else {
- parseDatatypeString(rel.getDatatypeStr)
- }
-
- val schemaStruct = schemaType match {
- case s: StructType => s
- case d => StructType(Seq(StructField("value", d)))
+ if (schema == null) {
+ throw InvalidPlanInput(
+ s"Schema for LocalRelation is required when the input data is not
provided.")
+ }
+ LocalRelation(schema.toAttributes, data = Seq.empty)
}
-
- Dataset
- .ofRows(session, logicalPlan = relation)
- .toDF(schemaStruct.names: _*)
- .to(schemaStruct)
- .logicalPlan
}
private def transformReadRel(rel: proto.Read): LogicalPlan = {
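Illustration only (not part of the diff): because the planner falls back from `parseTableSchema` to `parseDataType`, and finally wraps a bare type into `struct<...>`, all of the following schema strings should be accepted from the client; the `value` column name for a non-struct type comes from the wrapping above.
```
spark.createDataFrame([], schema="x STRING, y INTEGER")  # DataFrame[x: string, y: int]
spark.createDataFrame([], schema="x STRING")             # DataFrame[x: string]
spark.createDataFrame([], schema="STRING")               # DataFrame[value: string]
```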
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 57c9e801c22..646cc5ced9a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1426,9 +1426,6 @@ def _test() -> None:
# TODO(SPARK-41827): groupBy requires all cols be Column or str
del pyspark.sql.connect.dataframe.DataFrame.groupBy.__doc__
- # TODO(SPARK-41828): Implement creating empty DataFrame
- del pyspark.sql.connect.dataframe.DataFrame.isEmpty.__doc__
-
# TODO(SPARK-41829): Add Dataframe sort ordering
del pyspark.sql.connect.dataframe.DataFrame.sort.__doc__
del pyspark.sql.connect.dataframe.DataFrame.sortWithinPartitions.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 48a8fa598e7..1f4e4192fdf 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -270,30 +270,34 @@ class LocalRelation(LogicalPlan):
def __init__(
self,
- table: "pa.Table",
- schema: Optional[Union[DataType, str]] = None,
+ table: Optional["pa.Table"],
+ schema: Optional[str] = None,
) -> None:
super().__init__(None)
- assert table is not None and isinstance(table, pa.Table)
+
+ if table is None:
+ assert schema is not None
+ else:
+ assert isinstance(table, pa.Table)
+
+ assert schema is None or isinstance(schema, str)
+
self._table = table
- if schema is not None:
- assert isinstance(schema, (DataType, str))
self._schema = schema
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- sink = pa.BufferOutputStream()
- with pa.ipc.new_stream(sink, self._table.schema) as writer:
- for b in self._table.to_batches():
- writer.write_batch(b)
-
plan = proto.Relation()
- plan.local_relation.data = sink.getvalue().to_pybytes()
+
+ if self._table is not None:
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, self._table.schema) as writer:
+ for b in self._table.to_batches():
+ writer.write_batch(b)
+ plan.local_relation.data = sink.getvalue().to_pybytes()
+
if self._schema is not None:
- if isinstance(self._schema, DataType):
- plan.local_relation.datatype.CopyFrom(pyspark_types_to_proto_types(self._schema))
- elif isinstance(self._schema, str):
- plan.local_relation.datatype_str = self._schema
+ plan.local_relation.schema = self._schema
return plan
def print(self, indent: int = 0) -> str:
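A self-contained sketch (assumes only `pyarrow`; names are illustrative) of the Arrow IPC serialization that `plan()` performs when a table is present; for the empty-DataFrame case this step is skipped entirely and only the schema string is sent.
```
import pyarrow as pa

def serialize_to_ipc(table: pa.Table) -> bytes:
    # Mirrors the plan() logic above: write every record batch of the table
    # into a single Arrow IPC stream and return the raw bytes.
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        for batch in table.to_batches():
            writer.write_batch(batch)
    return sink.getvalue().to_pybytes()

# Round trip: the stream can be read back into an equivalent table.
data = serialize_to_ipc(pa.table({"x": ["a", "b"], "y": [1, 2]}))
assert pa.ipc.open_stream(data).read_all().num_rows == 2
```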
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index cf0f2eb3513..9e230c3d239 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -656,58 +656,58 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_DROP._serialized_end = 5290
_DEDUPLICATE._serialized_start = 5293
_DEDUPLICATE._serialized_end = 5464
- _LOCALRELATION._serialized_start = 5467
- _LOCALRELATION._serialized_end = 5604
- _SAMPLE._serialized_start = 5607
- _SAMPLE._serialized_end = 5880
- _RANGE._serialized_start = 5883
- _RANGE._serialized_end = 6028
- _SUBQUERYALIAS._serialized_start = 6030
- _SUBQUERYALIAS._serialized_end = 6144
- _REPARTITION._serialized_start = 6147
- _REPARTITION._serialized_end = 6289
- _SHOWSTRING._serialized_start = 6292
- _SHOWSTRING._serialized_end = 6434
- _STATSUMMARY._serialized_start = 6436
- _STATSUMMARY._serialized_end = 6528
- _STATDESCRIBE._serialized_start = 6530
- _STATDESCRIBE._serialized_end = 6611
- _STATCROSSTAB._serialized_start = 6613
- _STATCROSSTAB._serialized_end = 6714
- _STATCOV._serialized_start = 6716
- _STATCOV._serialized_end = 6812
- _STATCORR._serialized_start = 6815
- _STATCORR._serialized_end = 6952
- _STATAPPROXQUANTILE._serialized_start = 6955
- _STATAPPROXQUANTILE._serialized_end = 7119
- _STATFREQITEMS._serialized_start = 7121
- _STATFREQITEMS._serialized_end = 7246
- _STATSAMPLEBY._serialized_start = 7249
- _STATSAMPLEBY._serialized_end = 7558
- _STATSAMPLEBY_FRACTION._serialized_start = 7450
- _STATSAMPLEBY_FRACTION._serialized_end = 7549
- _NAFILL._serialized_start = 7561
- _NAFILL._serialized_end = 7695
- _NADROP._serialized_start = 7698
- _NADROP._serialized_end = 7832
- _NAREPLACE._serialized_start = 7835
- _NAREPLACE._serialized_end = 8131
- _NAREPLACE_REPLACEMENT._serialized_start = 7990
- _NAREPLACE_REPLACEMENT._serialized_end = 8131
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8133
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8247
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8250
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8509
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8442
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8509
- _WITHCOLUMNS._serialized_start = 8512
- _WITHCOLUMNS._serialized_end = 8643
- _HINT._serialized_start = 8646
- _HINT._serialized_end = 8786
- _UNPIVOT._serialized_start = 8789
- _UNPIVOT._serialized_end = 9035
- _TOSCHEMA._serialized_start = 9037
- _TOSCHEMA._serialized_end = 9143
- _REPARTITIONBYEXPRESSION._serialized_start = 9146
- _REPARTITIONBYEXPRESSION._serialized_end = 9349
+ _LOCALRELATION._serialized_start = 5466
+ _LOCALRELATION._serialized_end = 5555
+ _SAMPLE._serialized_start = 5558
+ _SAMPLE._serialized_end = 5831
+ _RANGE._serialized_start = 5834
+ _RANGE._serialized_end = 5979
+ _SUBQUERYALIAS._serialized_start = 5981
+ _SUBQUERYALIAS._serialized_end = 6095
+ _REPARTITION._serialized_start = 6098
+ _REPARTITION._serialized_end = 6240
+ _SHOWSTRING._serialized_start = 6243
+ _SHOWSTRING._serialized_end = 6385
+ _STATSUMMARY._serialized_start = 6387
+ _STATSUMMARY._serialized_end = 6479
+ _STATDESCRIBE._serialized_start = 6481
+ _STATDESCRIBE._serialized_end = 6562
+ _STATCROSSTAB._serialized_start = 6564
+ _STATCROSSTAB._serialized_end = 6665
+ _STATCOV._serialized_start = 6667
+ _STATCOV._serialized_end = 6763
+ _STATCORR._serialized_start = 6766
+ _STATCORR._serialized_end = 6903
+ _STATAPPROXQUANTILE._serialized_start = 6906
+ _STATAPPROXQUANTILE._serialized_end = 7070
+ _STATFREQITEMS._serialized_start = 7072
+ _STATFREQITEMS._serialized_end = 7197
+ _STATSAMPLEBY._serialized_start = 7200
+ _STATSAMPLEBY._serialized_end = 7509
+ _STATSAMPLEBY_FRACTION._serialized_start = 7401
+ _STATSAMPLEBY_FRACTION._serialized_end = 7500
+ _NAFILL._serialized_start = 7512
+ _NAFILL._serialized_end = 7646
+ _NADROP._serialized_start = 7649
+ _NADROP._serialized_end = 7783
+ _NAREPLACE._serialized_start = 7786
+ _NAREPLACE._serialized_end = 8082
+ _NAREPLACE_REPLACEMENT._serialized_start = 7941
+ _NAREPLACE_REPLACEMENT._serialized_end = 8082
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8084
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8198
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8201
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8460
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8393
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8460
+ _WITHCOLUMNS._serialized_start = 8463
+ _WITHCOLUMNS._serialized_end = 8594
+ _HINT._serialized_start = 8597
+ _HINT._serialized_end = 8737
+ _UNPIVOT._serialized_start = 8740
+ _UNPIVOT._serialized_end = 8986
+ _TOSCHEMA._serialized_start = 8988
+ _TOSCHEMA._serialized_end = 9094
+ _REPARTITIONBYEXPRESSION._serialized_start = 9097
+ _REPARTITIONBYEXPRESSION._serialized_end = 9300
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 7e63d363277..811f005d24b 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1268,45 +1268,44 @@ class LocalRelation(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DATA_FIELD_NUMBER: builtins.int
- DATATYPE_FIELD_NUMBER: builtins.int
- DATATYPE_STR_FIELD_NUMBER: builtins.int
+ SCHEMA_FIELD_NUMBER: builtins.int
data: builtins.bytes
- """Local collection data serialized into Arrow IPC streaming format which
contains
+ """(Optional) Local collection data serialized into Arrow IPC streaming
format which contains
the schema of the data.
"""
- @property
- def datatype(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
- datatype_str: builtins.str
- """Server will use Catalyst parser to parse this string to DataType."""
+ schema: builtins.str
+ """(Optional) The schema of local data.
+ It should be either a DDL-formatted type string or a JSON string.
+
+ The server side will update the column names and data types according to this schema.
+ If the 'data' is not provided, then this schema will be required.
+ """
def __init__(
self,
*,
- data: builtins.bytes = ...,
- datatype: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
- datatype_str: builtins.str = ...,
+ data: builtins.bytes | None = ...,
+ schema: builtins.str | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
- "datatype", b"datatype", "datatype_str", b"datatype_str",
"schema", b"schema"
+ "_data", b"_data", "_schema", b"_schema", "data", b"data",
"schema", b"schema"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "data",
- b"data",
- "datatype",
- b"datatype",
- "datatype_str",
- b"datatype_str",
- "schema",
- b"schema",
+ "_data", b"_data", "_schema", b"_schema", "data", b"data",
"schema", b"schema"
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_data", b"_data"]
+ ) -> typing_extensions.Literal["data"] | None: ...
+ @typing.overload
def WhichOneof(
- self, oneof_group: typing_extensions.Literal["schema", b"schema"]
- ) -> typing_extensions.Literal["datatype", "datatype_str"] | None: ...
+ self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
+ ) -> typing_extensions.Literal["schema"] | None: ...
global___LocalRelation = LocalRelation
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index a5d778e9c0e..09ad58fa3e0 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -31,6 +31,7 @@ from pyspark.sql.types import (
Row,
DataType,
StructType,
+ AtomicType,
)
from pyspark.sql.utils import to_str
@@ -177,20 +178,18 @@ class SparkSession:
def createDataFrame(
self,
data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
- schema: Optional[Union[StructType, str, List[str], Tuple[str, ...]]] = None,
+ schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None,
) -> "DataFrame":
assert data is not None
if isinstance(data, DataFrame):
raise TypeError("data is already a DataFrame")
- if isinstance(data, Sized) and len(data) == 0:
- raise ValueError("Input data cannot be empty")
table: Optional[pa.Table] = None
- _schema: Optional[StructType] = None
+ _schema: Optional[Union[AtomicType, StructType]] = None
_schema_str: Optional[str] = None
_cols: Optional[List[str]] = None
- if isinstance(schema, StructType):
+ if isinstance(schema, (AtomicType, StructType)):
_schema = schema
elif isinstance(schema, str):
@@ -200,6 +199,14 @@ class SparkSession:
# Must re-encode any unicode strings to be consistent with StructField names
_cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
+ if isinstance(data, Sized) and len(data) == 0:
+ if _schema is not None:
+ return DataFrame.withPlan(LocalRelation(table=None, schema=_schema.json()), self)
+ elif _schema_str is not None:
+ return DataFrame.withPlan(LocalRelation(table=None, schema=_schema_str), self)
+ else:
+ raise ValueError("can not infer schema from empty dataset")
+
if isinstance(data, pd.DataFrame):
table = pa.Table.from_pandas(data)
@@ -253,8 +260,10 @@ class SparkSession:
_cols = ["_%s" % i for i in range(1, len(_data[0]) +
1)]
else:
_cols = ["_1"]
- else:
+ elif isinstance(_schema, StructType):
_cols = _schema.names
+ else:
+ _cols = ["value"]
if isinstance(_data[0], Row):
table = pa.Table.from_pylist([row.asDict(recursive=True) for row in _data])
@@ -268,19 +277,24 @@ class SparkSession:
# Validate number of columns
num_cols = table.shape[1]
- if _schema is not None and len(_schema.fields) != num_cols:
+ if (
+ _schema is not None
+ and isinstance(_schema, StructType)
+ and len(_schema.fields) != num_cols
+ ):
raise ValueError(
f"Length mismatch: Expected axis has {num_cols} elements, "
f"new values have {len(_schema.fields)} elements"
)
- elif _cols is not None and len(_cols) != num_cols:
+
+ if _cols is not None and len(_cols) != num_cols:
raise ValueError(
f"Length mismatch: Expected axis has {num_cols} elements, "
f"new values have {len(_cols)} elements"
)
if _schema is not None:
- return DataFrame.withPlan(LocalRelation(table, schema=_schema), self)
+ return DataFrame.withPlan(LocalRelation(table, schema=_schema.json()), self)
elif _schema_str is not None:
return DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self)
elif _cols is not None and len(_cols) > 0:
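One detail worth calling out (sketch only, not from the diff): `StructType`/`AtomicType` schemas are now shipped to the server as their JSON form, which the planner's fallback parser (`DataType.fromJson` on the Scala side) can recover; `StructType.fromJson` is used here purely as the Python analogue to show the round trip.
```
import json
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([StructField("x", StringType()), StructField("y", IntegerType())])
payload = schema.json()  # what the client now places into LocalRelation.schema

# The JSON form parses back into an equal StructType, mirroring what the
# server does with DataType.fromJson when DDL parsing fails.
assert StructType.fromJson(json.loads(payload)) == schema
```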
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index e82dc7f7f76..fe6c2c65e25 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -525,6 +525,34 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
sdf.select(SF.pmod("a", "b")).toPandas(),
)
+ def test_create_empty_df(self):
+ for schema in [
+ "STRING",
+ "x STRING",
+ "x STRING, y INTEGER",
+ StringType(),
+ StructType(
+ [
+ StructField("x", StringType(), True),
+ StructField("y", IntegerType(), True),
+ ]
+ ),
+ ]:
+ cdf = self.connect.createDataFrame(data=[], schema=schema)
+ sdf = self.spark.createDataFrame(data=[], schema=schema)
+
+ self.assert_eq(cdf.toPandas(), sdf.toPandas())
+
+ # check error
+ with self.assertRaisesRegex(
+ ValueError,
+ "can not infer schema from empty dataset",
+ ):
+ self.connect.createDataFrame(data=[])
+
def test_simple_explain_string(self):
df = self.connect.read.table(self.tbl_name).limit(10)
result = df._explain_string()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]