This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7143cbd0725 [SPARK-41440][CONNECT][PYTHON] Avoid the cache operator for general Sample
7143cbd0725 is described below
commit 7143cbd072557c8ea231b378572e8a7554d8a3f5
Author: Jiaan Geng <[email protected]>
AuthorDate: Fri Dec 30 20:22:21 2022 +0800
[SPARK-41440][CONNECT][PYTHON] Avoid the cache operator for general Sample
### What changes were proposed in this pull request?
https://github.com/apache/spark/pull/39017 added support for
`DataFrame.randomSplit`, but it cached the Sample plan incorrectly.
### Why are the changes needed?
This PR avoids the cache operator for a general `Sample`.
This PR also gives a more suitable name, `deterministic_order`, to replace
`force_stable_sort`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tests updated.
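For reference, the bounds asserted in the updated tests follow from
cumulatively normalizing the split weights. A minimal plain-Python sketch
(illustrative only, no Spark required):

    def split_bounds(weights):
        # Normalize the weights, then accumulate them into
        # (lower_bound, upper_bound) pairs, one per split.
        total = sum(weights)
        bounds, cumulative = [], 0.0
        for w in weights:
            bounds.append((cumulative, cumulative + w / total))
            cumulative += w / total
        return bounds

    print(split_bounds([1.0, 2.0, 3.0]))
    # [(0.0, 0.1666...), (0.1666..., 0.5), (0.5, 1.0)]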
Closes #39240 from beliefer/SPARK-41440_followup2.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 7 +-
.../org/apache/spark/sql/connect/dsl/package.scala | 2 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 9 ++-
python/pyspark/sql/connect/dataframe.py | 2 +-
python/pyspark/sql/connect/plan.py | 6 +-
python/pyspark/sql/connect/proto/relations_pb2.py | 88 +++++++++++-----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 25 ++----
.../sql/tests/connect/test_connect_plan_only.py | 10 +--
8 files changed, 71 insertions(+), 78 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index afff04f8f0d..3bb0b362b27 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -358,9 +358,10 @@ message Sample {
// (Optional) The random seed.
optional int64 seed = 5;
- // (Optional) Explicitly sort the underlying plan to make the ordering deterministic.
- // This flag is only used to randomly splits DataFrame with the provided weights.
- optional bool force_stable_sort = 6;
+ // (Required) Explicitly sort the underlying plan to make the ordering deterministic or cache it.
+ // This flag is true when invoking `dataframe.randomSplit` to randomly split the DataFrame
+ // with the provided weights. Otherwise, it is false.
+ bool deterministic_order = 6;
}
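A hedged sketch of what dropping `optional` means for the Python client
(assuming the generated `Sample` class is importable via the `proto`
package, as `Relation` is in plan.py below):

    import pyspark.sql.connect.proto as proto

    sample = proto.Sample()
    print(sample.deterministic_order)  # False: plain proto3 bool, implicit default
    sample.deterministic_order = True
    # Non-optional proto3 scalars carry no presence tracking, so
    # HasField("deterministic_order") is no longer valid; the .pyi stub
    # changes below drop the field from HasField accordingly.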
// Relation of type [[Range]] that generates a sequence of integers.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 3bd713a9710..c4a5eac46c0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -964,7 +964,7 @@ package object dsl {
.setUpperBound(x(1))
.setWithReplacement(false)
.setSeed(seed)
- .setForceStableSort(true)
+ .setDeterministicOrder(true)
.build())
.build()
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7d6fdc2883e..a11ebd8b7d1 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -211,8 +211,9 @@ class SparkConnectPlanner(session: SparkSession) {
* wrap such fields into proto messages.
*/
private def transformSample(rel: proto.Sample): LogicalPlan = {
- val input = Dataset.ofRows(session, transformRelation(rel.getInput))
- val plan = if (rel.getForceStableSort) {
+ val plan = if (rel.getDeterministicOrder) {
+ val input = Dataset.ofRows(session, transformRelation(rel.getInput))
+
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
@@ -224,11 +225,11 @@ class SparkConnectPlanner(session: SparkSession) {
if (sortOrder.nonEmpty) {
Sort(sortOrder, global = false, input.logicalPlan)
} else {
+ input.cache()
input.logicalPlan
}
} else {
- input.cache()
- input.logicalPlan
+ transformRelation(rel.getInput)
}
Sample(
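To make the comment above concrete, here is a hedged plain-Python analogy
(not Spark's actual sampler): draw one pseudo-random number per row in
iteration order and keep the rows whose draw falls in [lower, upper). If the
row order can change between materializations, the splits can overlap;
sorting first pins every row to a stable draw.

    import random

    def seeded_split(rows, lower, upper, seed=42):
        # One draw per row, consumed in iteration order, Bernoulli-style;
        # keep rows whose draw lands in [lower, upper).
        rng = random.Random(seed)
        return {r for r in rows if lower <= rng.random() < upper}

    rows = list(range(100))
    reordered = random.Random(0).sample(rows, len(rows))  # a re-materialized order

    left = seeded_split(rows, 0.0, 0.5)
    right = seeded_split(reordered, 0.5, 1.0)
    print(len(left & right))  # > 0: the two "splits" overlap

    left = seeded_split(sorted(rows), 0.0, 0.5)
    right = seeded_split(sorted(reordered), 0.5, 1.0)
    print(len(left & right))  # 0: a deterministic order makes them disjoint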
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 256e63122ab..018785b77b0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -523,7 +523,7 @@ class DataFrame:
upper_bound=upperBound,
with_replacement=False,
seed=int(seed),
- force_stable_sort=True,
+ deterministic_order=True,
),
session=self._session,
)
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index e1b9fa0d0e4..069a1329e80 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -558,14 +558,14 @@ class Sample(LogicalPlan):
upper_bound: float,
with_replacement: bool,
seed: Optional[int],
- force_stable_sort: bool = False,
+ deterministic_order: bool = False,
) -> None:
super().__init__(child)
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.with_replacement = with_replacement
self.seed = seed
- self.force_stable_sort = force_stable_sort
+ self.deterministic_order = deterministic_order
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
@@ -576,7 +576,7 @@ class Sample(LogicalPlan):
plan.sample.with_replacement = self.with_replacement
if self.seed is not None:
plan.sample.seed = self.seed
- plan.sample.force_stable_sort = self.force_stable_sort
+ plan.sample.deterministic_order = self.deterministic_order
return plan
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 92a6d252a68..205caca07b1 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xa4\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xa4\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -612,47 +612,47 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_LOCALRELATION._serialized_start = 5266
_LOCALRELATION._serialized_end = 5403
_SAMPLE._serialized_start = 5406
- _SAMPLE._serialized_end = 5701
- _RANGE._serialized_start = 5704
- _RANGE._serialized_end = 5849
- _SUBQUERYALIAS._serialized_start = 5851
- _SUBQUERYALIAS._serialized_end = 5965
- _REPARTITION._serialized_start = 5968
- _REPARTITION._serialized_end = 6110
- _SHOWSTRING._serialized_start = 6113
- _SHOWSTRING._serialized_end = 6255
- _STATSUMMARY._serialized_start = 6257
- _STATSUMMARY._serialized_end = 6349
- _STATDESCRIBE._serialized_start = 6351
- _STATDESCRIBE._serialized_end = 6432
- _STATCROSSTAB._serialized_start = 6434
- _STATCROSSTAB._serialized_end = 6535
- _STATCOV._serialized_start = 6537
- _STATCOV._serialized_end = 6633
- _STATCORR._serialized_start = 6636
- _STATCORR._serialized_end = 6773
- _NAFILL._serialized_start = 6776
- _NAFILL._serialized_end = 6910
- _NADROP._serialized_start = 6913
- _NADROP._serialized_end = 7047
- _NAREPLACE._serialized_start = 7050
- _NAREPLACE._serialized_end = 7346
- _NAREPLACE_REPLACEMENT._serialized_start = 7205
- _NAREPLACE_REPLACEMENT._serialized_end = 7346
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7348
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7462
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7465
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7724
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7657
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7724
- _WITHCOLUMNS._serialized_start = 7727
- _WITHCOLUMNS._serialized_end = 7858
- _HINT._serialized_start = 7861
- _HINT._serialized_end = 8001
- _UNPIVOT._serialized_start = 8004
- _UNPIVOT._serialized_end = 8250
- _TOSCHEMA._serialized_start = 8252
- _TOSCHEMA._serialized_end = 8358
- _REPARTITIONBYEXPRESSION._serialized_start = 8361
- _REPARTITIONBYEXPRESSION._serialized_end = 8564
+ _SAMPLE._serialized_end = 5679
+ _RANGE._serialized_start = 5682
+ _RANGE._serialized_end = 5827
+ _SUBQUERYALIAS._serialized_start = 5829
+ _SUBQUERYALIAS._serialized_end = 5943
+ _REPARTITION._serialized_start = 5946
+ _REPARTITION._serialized_end = 6088
+ _SHOWSTRING._serialized_start = 6091
+ _SHOWSTRING._serialized_end = 6233
+ _STATSUMMARY._serialized_start = 6235
+ _STATSUMMARY._serialized_end = 6327
+ _STATDESCRIBE._serialized_start = 6329
+ _STATDESCRIBE._serialized_end = 6410
+ _STATCROSSTAB._serialized_start = 6412
+ _STATCROSSTAB._serialized_end = 6513
+ _STATCOV._serialized_start = 6515
+ _STATCOV._serialized_end = 6611
+ _STATCORR._serialized_start = 6614
+ _STATCORR._serialized_end = 6751
+ _NAFILL._serialized_start = 6754
+ _NAFILL._serialized_end = 6888
+ _NADROP._serialized_start = 6891
+ _NADROP._serialized_end = 7025
+ _NAREPLACE._serialized_start = 7028
+ _NAREPLACE._serialized_end = 7324
+ _NAREPLACE_REPLACEMENT._serialized_start = 7183
+ _NAREPLACE_REPLACEMENT._serialized_end = 7324
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7326
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7440
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7443
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7702
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7635
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7702
+ _WITHCOLUMNS._serialized_start = 7705
+ _WITHCOLUMNS._serialized_end = 7836
+ _HINT._serialized_start = 7839
+ _HINT._serialized_end = 7979
+ _UNPIVOT._serialized_start = 7982
+ _UNPIVOT._serialized_end = 8228
+ _TOSCHEMA._serialized_start = 8230
+ _TOSCHEMA._serialized_end = 8336
+ _REPARTITIONBYEXPRESSION._serialized_start = 8339
+ _REPARTITIONBYEXPRESSION._serialized_end = 8542
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 36b55ec4f6b..e83e64f64e8 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1293,7 +1293,7 @@ class Sample(google.protobuf.message.Message):
UPPER_BOUND_FIELD_NUMBER: builtins.int
WITH_REPLACEMENT_FIELD_NUMBER: builtins.int
SEED_FIELD_NUMBER: builtins.int
- FORCE_STABLE_SORT_FIELD_NUMBER: builtins.int
+ DETERMINISTIC_ORDER_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for a Sample."""
@@ -1305,9 +1305,10 @@ class Sample(google.protobuf.message.Message):
"""(Optional) Whether to sample with replacement."""
seed: builtins.int
"""(Optional) The random seed."""
- force_stable_sort: builtins.bool
- """(Optional) Explicitly sort the underlying plan to make the ordering deterministic.
- This flag is only used to randomly splits DataFrame with the provided weights.
+ deterministic_order: builtins.bool
+ """(Required) Explicitly sort the underlying plan to make the ordering deterministic or cache it.
+ This flag is true when invoking `dataframe.randomSplit` to randomly split the DataFrame
+ with the provided weights. Otherwise, it is false.
"""
def __init__(
self,
@@ -1317,19 +1318,15 @@ class Sample(google.protobuf.message.Message):
upper_bound: builtins.float = ...,
with_replacement: builtins.bool | None = ...,
seed: builtins.int | None = ...,
- force_stable_sort: builtins.bool | None = ...,
+ deterministic_order: builtins.bool = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
- "_force_stable_sort",
- b"_force_stable_sort",
"_seed",
b"_seed",
"_with_replacement",
b"_with_replacement",
- "force_stable_sort",
- b"force_stable_sort",
"input",
b"input",
"seed",
@@ -1341,14 +1338,12 @@ class Sample(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
- "_force_stable_sort",
- b"_force_stable_sort",
"_seed",
b"_seed",
"_with_replacement",
b"_with_replacement",
- "force_stable_sort",
- b"force_stable_sort",
+ "deterministic_order",
+ b"deterministic_order",
"input",
b"input",
"lower_bound",
@@ -1362,10 +1357,6 @@ class Sample(google.protobuf.message.Message):
],
) -> None: ...
@typing.overload
- def WhichOneof(
- self, oneof_group: typing_extensions.Literal["_force_stable_sort", b"_force_stable_sort"]
- ) -> typing_extensions.Literal["force_stable_sort"] | None: ...
- @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
) -> typing_extensions.Literal["seed"] | None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 240fd7d4d72..529e3ca3eda 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -251,21 +251,21 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.sample.upper_bound, 0.16666666666666666)
self.assertEqual(plan.root.sample.with_replacement, False)
self.assertEqual(plan.root.sample.HasField("seed"), True)
- self.assertEqual(plan.root.sample.force_stable_sort, True)
+ self.assertEqual(plan.root.sample.deterministic_order, True)
plan = relations[1]._plan.to_proto(self.connect)
self.assertEqual(plan.root.sample.lower_bound, 0.16666666666666666)
self.assertEqual(plan.root.sample.upper_bound, 0.5)
self.assertEqual(plan.root.sample.with_replacement, False)
self.assertEqual(plan.root.sample.HasField("seed"), True)
- self.assertEqual(plan.root.sample.force_stable_sort, True)
+ self.assertEqual(plan.root.sample.deterministic_order, True)
plan = relations[2]._plan.to_proto(self.connect)
self.assertEqual(plan.root.sample.lower_bound, 0.5)
self.assertEqual(plan.root.sample.upper_bound, 1.0)
self.assertEqual(plan.root.sample.with_replacement, False)
self.assertEqual(plan.root.sample.HasField("seed"), True)
- self.assertEqual(plan.root.sample.force_stable_sort, True)
+ self.assertEqual(plan.root.sample.deterministic_order, True)
relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0], 1)
checkRelations(relations)
@@ -326,7 +326,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.sample.upper_bound, 0.3)
self.assertEqual(plan.root.sample.with_replacement, False)
self.assertEqual(plan.root.sample.HasField("seed"), False)
- self.assertEqual(plan.root.sample.force_stable_sort, False)
+ self.assertEqual(plan.root.sample.deterministic_order, False)
plan = (
df.filter(df.col_name > 3)
@@ -337,7 +337,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.sample.upper_bound, 0.4)
self.assertEqual(plan.root.sample.with_replacement, True)
self.assertEqual(plan.root.sample.seed, -1)
- self.assertEqual(plan.root.sample.force_stable_sort, False)
+ self.assertEqual(plan.root.sample.deterministic_order, False)
def test_sort(self):
df = self.connect.readTable(table_name=self.tbl_name)
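Finally, a short hedged usage sketch of the behavior exercised by these
tests (assuming a Spark Connect session `spark` and an existing table name
in `tbl`):

    df = spark.read.table(tbl)
    # randomSplit requests a deterministic ordering (deterministic_order=True),
    # so the resulting splits cannot overlap:
    small, medium, large = df.randomSplit([1.0, 2.0, 3.0], seed=1)
    # A plain sample leaves deterministic_order=False; after this change it
    # no longer triggers an implicit cache of the input:
    sampled = df.sample(fraction=0.3)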