This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7143cbd0725 [SPARK-41440][CONNECT][PYTHON] Avoid the cache operator for general Sample 7143cbd0725 is described below commit 7143cbd072557c8ea231b378572e8a7554d8a3f5 Author: Jiaan Geng <belie...@163.com> AuthorDate: Fri Dec 30 20:22:21 2022 +0800 [SPARK-41440][CONNECT][PYTHON] Avoid the cache operator for general Sample ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/39017 supported `DataFrame.randomSplit`. But it cached the Sample plan incorrectly. ### Why are the changes needed? This PR avoids the cache operator for general `Sample`. This PR also gives a more suitable name `deterministic_order` to replace `force_stable_sort`. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? Tests updated. Closes #39240 from beliefer/SPARK-41440_followup2. Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/protobuf/spark/connect/relations.proto | 7 +- .../org/apache/spark/sql/connect/dsl/package.scala | 2 +- .../sql/connect/planner/SparkConnectPlanner.scala | 9 ++- python/pyspark/sql/connect/dataframe.py | 2 +- python/pyspark/sql/connect/plan.py | 6 +- python/pyspark/sql/connect/proto/relations_pb2.py | 88 +++++++++++----------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 25 ++---- .../sql/tests/connect/test_connect_plan_only.py | 10 +-- 8 files changed, 71 insertions(+), 78 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index afff04f8f0d..3bb0b362b27 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -358,9 +358,10 @@ message Sample { // (Optional) The random seed. 
optional int64 seed = 5; - // (Optional) Explicitly sort the underlying plan to make the ordering deterministic. - // This flag is only used to randomly splits DataFrame with the provided weights. - optional bool force_stable_sort = 6; + // (Required) Explicitly sort the underlying plan to make the ordering deterministic or cache it. + // This flag is true when invoking `dataframe.randomSplit` to randomly splits DataFrame with the + // provided weights. Otherwise, it is false. + bool deterministic_order = 6; } // Relation of type [[Range]] that generates a sequence of integers. diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 3bd713a9710..c4a5eac46c0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -964,7 +964,7 @@ package object dsl { .setUpperBound(x(1)) .setWithReplacement(false) .setSeed(seed) - .setForceStableSort(true) + .setDeterministicOrder(true) .build()) .build() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7d6fdc2883e..a11ebd8b7d1 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -211,8 +211,9 @@ class SparkConnectPlanner(session: SparkSession) { * wrap such fields into proto messages. 
*/ private def transformSample(rel: proto.Sample): LogicalPlan = { - val input = Dataset.ofRows(session, transformRelation(rel.getInput)) - val plan = if (rel.getForceStableSort) { + val plan = if (rel.getDeterministicOrder) { + val input = Dataset.ofRows(session, transformRelation(rel.getInput)) + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the @@ -224,11 +225,11 @@ class SparkConnectPlanner(session: SparkSession) { if (sortOrder.nonEmpty) { Sort(sortOrder, global = false, input.logicalPlan) } else { + input.cache() input.logicalPlan } } else { - input.cache() - input.logicalPlan + transformRelation(rel.getInput) } Sample( diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 256e63122ab..018785b77b0 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -523,7 +523,7 @@ class DataFrame: upper_bound=upperBound, with_replacement=False, seed=int(seed), - force_stable_sort=True, + deterministic_order=True, ), session=self._session, ) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index e1b9fa0d0e4..069a1329e80 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -558,14 +558,14 @@ class Sample(LogicalPlan): upper_bound: float, with_replacement: bool, seed: Optional[int], - force_stable_sort: bool = False, + deterministic_order: bool = False, ) -> None: super().__init__(child) self.lower_bound = lower_bound self.upper_bound = upper_bound self.with_replacement = with_replacement self.seed = seed - self.force_stable_sort = force_stable_sort + self.deterministic_order = deterministic_order def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None @@ 
-576,7 +576,7 @@ class Sample(LogicalPlan): plan.sample.with_replacement = self.with_replacement if self.seed is not None: plan.sample.seed = self.seed - plan.sample.force_stable_sort = self.force_stable_sort + plan.sample.deterministic_order = self.deterministic_order return plan diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 92a6d252a68..205caca07b1 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xa4\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xa4\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] 
) @@ -612,47 +612,47 @@ if _descriptor._USE_C_DESCRIPTORS == False: _LOCALRELATION._serialized_start = 5266 _LOCALRELATION._serialized_end = 5403 _SAMPLE._serialized_start = 5406 - _SAMPLE._serialized_end = 5701 - _RANGE._serialized_start = 5704 - _RANGE._serialized_end = 5849 - _SUBQUERYALIAS._serialized_start = 5851 - _SUBQUERYALIAS._serialized_end = 5965 - _REPARTITION._serialized_start = 5968 - _REPARTITION._serialized_end = 6110 - _SHOWSTRING._serialized_start = 6113 - _SHOWSTRING._serialized_end = 6255 - _STATSUMMARY._serialized_start = 6257 - _STATSUMMARY._serialized_end = 6349 - _STATDESCRIBE._serialized_start = 6351 - _STATDESCRIBE._serialized_end = 6432 - _STATCROSSTAB._serialized_start = 6434 - _STATCROSSTAB._serialized_end = 6535 - _STATCOV._serialized_start = 6537 - _STATCOV._serialized_end = 6633 - _STATCORR._serialized_start = 6636 - _STATCORR._serialized_end = 6773 - _NAFILL._serialized_start = 6776 - _NAFILL._serialized_end = 6910 - _NADROP._serialized_start = 6913 - _NADROP._serialized_end = 7047 - _NAREPLACE._serialized_start = 7050 - _NAREPLACE._serialized_end = 7346 - _NAREPLACE_REPLACEMENT._serialized_start = 7205 - _NAREPLACE_REPLACEMENT._serialized_end = 7346 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7348 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7462 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7465 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7724 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7657 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7724 - _WITHCOLUMNS._serialized_start = 7727 - _WITHCOLUMNS._serialized_end = 7858 - _HINT._serialized_start = 7861 - _HINT._serialized_end = 8001 - _UNPIVOT._serialized_start = 8004 - _UNPIVOT._serialized_end = 8250 - _TOSCHEMA._serialized_start = 8252 - _TOSCHEMA._serialized_end = 8358 - _REPARTITIONBYEXPRESSION._serialized_start = 8361 - _REPARTITIONBYEXPRESSION._serialized_end = 8564 + _SAMPLE._serialized_end = 
5679 + _RANGE._serialized_start = 5682 + _RANGE._serialized_end = 5827 + _SUBQUERYALIAS._serialized_start = 5829 + _SUBQUERYALIAS._serialized_end = 5943 + _REPARTITION._serialized_start = 5946 + _REPARTITION._serialized_end = 6088 + _SHOWSTRING._serialized_start = 6091 + _SHOWSTRING._serialized_end = 6233 + _STATSUMMARY._serialized_start = 6235 + _STATSUMMARY._serialized_end = 6327 + _STATDESCRIBE._serialized_start = 6329 + _STATDESCRIBE._serialized_end = 6410 + _STATCROSSTAB._serialized_start = 6412 + _STATCROSSTAB._serialized_end = 6513 + _STATCOV._serialized_start = 6515 + _STATCOV._serialized_end = 6611 + _STATCORR._serialized_start = 6614 + _STATCORR._serialized_end = 6751 + _NAFILL._serialized_start = 6754 + _NAFILL._serialized_end = 6888 + _NADROP._serialized_start = 6891 + _NADROP._serialized_end = 7025 + _NAREPLACE._serialized_start = 7028 + _NAREPLACE._serialized_end = 7324 + _NAREPLACE_REPLACEMENT._serialized_start = 7183 + _NAREPLACE_REPLACEMENT._serialized_end = 7324 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7326 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7440 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7443 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7702 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7635 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7702 + _WITHCOLUMNS._serialized_start = 7705 + _WITHCOLUMNS._serialized_end = 7836 + _HINT._serialized_start = 7839 + _HINT._serialized_end = 7979 + _UNPIVOT._serialized_start = 7982 + _UNPIVOT._serialized_end = 8228 + _TOSCHEMA._serialized_start = 8230 + _TOSCHEMA._serialized_end = 8336 + _REPARTITIONBYEXPRESSION._serialized_start = 8339 + _REPARTITIONBYEXPRESSION._serialized_end = 8542 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 36b55ec4f6b..e83e64f64e8 100644 --- 
a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -1293,7 +1293,7 @@ class Sample(google.protobuf.message.Message): UPPER_BOUND_FIELD_NUMBER: builtins.int WITH_REPLACEMENT_FIELD_NUMBER: builtins.int SEED_FIELD_NUMBER: builtins.int - FORCE_STABLE_SORT_FIELD_NUMBER: builtins.int + DETERMINISTIC_ORDER_FIELD_NUMBER: builtins.int @property def input(self) -> global___Relation: """(Required) Input relation for a Sample.""" @@ -1305,9 +1305,10 @@ class Sample(google.protobuf.message.Message): """(Optional) Whether to sample with replacement.""" seed: builtins.int """(Optional) The random seed.""" - force_stable_sort: builtins.bool - """(Optional) Explicitly sort the underlying plan to make the ordering deterministic. - This flag is only used to randomly splits DataFrame with the provided weights. + deterministic_order: builtins.bool + """(Required) Explicitly sort the underlying plan to make the ordering deterministic or cache it. + This flag is true when invoking `dataframe.randomSplit` to randomly splits DataFrame with the + provided weights. Otherwise, it is false. """ def __init__( self, @@ -1317,19 +1318,15 @@ class Sample(google.protobuf.message.Message): upper_bound: builtins.float = ..., with_replacement: builtins.bool | None = ..., seed: builtins.int | None = ..., - force_stable_sort: builtins.bool | None = ..., + deterministic_order: builtins.bool = ..., ) -> None: ... 
def HasField( self, field_name: typing_extensions.Literal[ - "_force_stable_sort", - b"_force_stable_sort", "_seed", b"_seed", "_with_replacement", b"_with_replacement", - "force_stable_sort", - b"force_stable_sort", "input", b"input", "seed", @@ -1341,14 +1338,12 @@ class Sample(google.protobuf.message.Message): def ClearField( self, field_name: typing_extensions.Literal[ - "_force_stable_sort", - b"_force_stable_sort", "_seed", b"_seed", "_with_replacement", b"_with_replacement", - "force_stable_sort", - b"force_stable_sort", + "deterministic_order", + b"deterministic_order", "input", b"input", "lower_bound", @@ -1362,10 +1357,6 @@ class Sample(google.protobuf.message.Message): ], ) -> None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing_extensions.Literal["_force_stable_sort", b"_force_stable_sort"] - ) -> typing_extensions.Literal["force_stable_sort"] | None: ... - @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["_seed", b"_seed"] ) -> typing_extensions.Literal["seed"] | None: ... 
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 240fd7d4d72..529e3ca3eda 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -251,21 +251,21 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): self.assertEqual(plan.root.sample.upper_bound, 0.16666666666666666) self.assertEqual(plan.root.sample.with_replacement, False) self.assertEqual(plan.root.sample.HasField("seed"), True) - self.assertEqual(plan.root.sample.force_stable_sort, True) + self.assertEqual(plan.root.sample.deterministic_order, True) plan = relations[1]._plan.to_proto(self.connect) self.assertEqual(plan.root.sample.lower_bound, 0.16666666666666666) self.assertEqual(plan.root.sample.upper_bound, 0.5) self.assertEqual(plan.root.sample.with_replacement, False) self.assertEqual(plan.root.sample.HasField("seed"), True) - self.assertEqual(plan.root.sample.force_stable_sort, True) + self.assertEqual(plan.root.sample.deterministic_order, True) plan = relations[2]._plan.to_proto(self.connect) self.assertEqual(plan.root.sample.lower_bound, 0.5) self.assertEqual(plan.root.sample.upper_bound, 1.0) self.assertEqual(plan.root.sample.with_replacement, False) self.assertEqual(plan.root.sample.HasField("seed"), True) - self.assertEqual(plan.root.sample.force_stable_sort, True) + self.assertEqual(plan.root.sample.deterministic_order, True) relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0], 1) checkRelations(relations) @@ -326,7 +326,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): self.assertEqual(plan.root.sample.upper_bound, 0.3) self.assertEqual(plan.root.sample.with_replacement, False) self.assertEqual(plan.root.sample.HasField("seed"), False) - self.assertEqual(plan.root.sample.force_stable_sort, False) + self.assertEqual(plan.root.sample.deterministic_order, False) plan = ( df.filter(df.col_name > 3) @@ 
-337,7 +337,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): self.assertEqual(plan.root.sample.upper_bound, 0.4) self.assertEqual(plan.root.sample.with_replacement, True) self.assertEqual(plan.root.sample.seed, -1) - self.assertEqual(plan.root.sample.force_stable_sort, False) + self.assertEqual(plan.root.sample.deterministic_order, False) def test_sort(self): df = self.connect.readTable(table_name=self.tbl_name) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org