This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7143cbd0725 [SPARK-41440][CONNECT][PYTHON] Avoid the cache operator for general Sample
7143cbd0725 is described below

commit 7143cbd072557c8ea231b378572e8a7554d8a3f5
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Fri Dec 30 20:22:21 2022 +0800

    [SPARK-41440][CONNECT][PYTHON] Avoid the cache operator for general Sample
    
    ### What changes were proposed in this pull request?
    https://github.com/apache/spark/pull/39017 added support for `DataFrame.randomSplit`, but it cached the Sample plan incorrectly.
    
    ### Why are the changes needed?
    This PR avoids the cache operator for general `Sample`.
    It also gives the flag a more suitable name, `deterministic_order`, replacing `force_stable_sort`.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    Tests updated.
    
    Closes #39240 from beliefer/SPARK-41440_followup2.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  7 +-
 .../org/apache/spark/sql/connect/dsl/package.scala |  2 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  9 ++-
 python/pyspark/sql/connect/dataframe.py            |  2 +-
 python/pyspark/sql/connect/plan.py                 |  6 +-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 88 +++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 25 ++----
 .../sql/tests/connect/test_connect_plan_only.py    | 10 +--
 8 files changed, 71 insertions(+), 78 deletions(-)
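
For context on the numbers asserted in the test changes below: `randomSplit`
normalizes its weights and accumulates them into [lower, upper) bounds, one
`Sample` relation per split, so weights [1.0, 2.0, 3.0] yield the bounds
0.1666..., 0.5 and 1.0 seen in the tests. A minimal standalone sketch of that
bound computation (illustrative only, not the PySpark implementation):

    from math import isclose

    def split_bounds(weights):
        # Normalize the weights and accumulate them into [lower, upper) pairs.
        total = sum(weights)
        bounds = [0.0]
        for w in weights:
            bounds.append(bounds[-1] + w / total)
        return list(zip(bounds, bounds[1:]))

    pairs = split_bounds([1.0, 2.0, 3.0])
    expected = [(0.0, 1 / 6), (1 / 6, 0.5), (0.5, 1.0)]
    assert all(
        isclose(got, exp) for p, e in zip(pairs, expected) for got, exp in zip(p, e)
    )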

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index afff04f8f0d..3bb0b362b27 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -358,9 +358,10 @@ message Sample {
   // (Optional) The random seed.
   optional int64 seed = 5;
 
-  // (Optional) Explicitly sort the underlying plan to make the ordering deterministic.
-  // This flag is only used to randomly splits DataFrame with the provided weights.
-  optional bool force_stable_sort = 6;
+  // (Required) Explicitly sort the underlying plan to make the ordering deterministic or cache it.
+  // This flag is true when invoking `dataframe.randomSplit` to randomly split the DataFrame with
+  // the provided weights. Otherwise, it is false.
+  bool deterministic_order = 6;
 }
 
 // Relation of type [[Range]] that generates a sequence of integers.
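
A note on the proto change above: dropping `optional` from a proto3 scalar
keeps field number 6 and its wire encoding, but removes explicit presence
tracking, which is why the regenerated Python stubs below delete
`force_stable_sort` from the `HasField` overloads. A quick check against the
regenerated module might look like this (sketch; the behavior follows
standard proto3 semantics):

    from pyspark.sql.connect.proto import relations_pb2

    s = relations_pb2.Sample()
    print(s.deterministic_order)  # False: plain proto3 scalar, default in place
    s.deterministic_order = True
    print(s.deterministic_order)  # True
    # s.HasField("deterministic_order") would now raise ValueError, since the
    # field no longer tracks presence; the still-optional `seed` field does:
    print(s.HasField("seed"))     # False until seed is explicitly assigned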
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 3bd713a9710..c4a5eac46c0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -964,7 +964,7 @@ package object dsl {
                   .setUpperBound(x(1))
                   .setWithReplacement(false)
                   .setSeed(seed)
-                  .setForceStableSort(true)
+                  .setDeterministicOrder(true)
                   .build())
               .build()
           }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7d6fdc2883e..a11ebd8b7d1 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -211,8 +211,9 @@ class SparkConnectPlanner(session: SparkSession) {
    * wrap such fields into proto messages.
    */
   private def transformSample(rel: proto.Sample): LogicalPlan = {
-    val input = Dataset.ofRows(session, transformRelation(rel.getInput))
-    val plan = if (rel.getForceStableSort) {
+    val plan = if (rel.getDeterministicOrder) {
+      val input = Dataset.ofRows(session, transformRelation(rel.getInput))
+
      // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
      // constituent partitions each time a split is materialized which could result in
      // overlapping splits. To prevent this, we explicitly sort each input partition to make the
@@ -224,11 +225,11 @@ class SparkConnectPlanner(session: SparkSession) {
       if (sortOrder.nonEmpty) {
         Sort(sortOrder, global = false, input.logicalPlan)
       } else {
+        input.cache()
         input.logicalPlan
       }
     } else {
-      input.cache()
-      input.logicalPlan
+      transformRelation(rel.getInput)
     }
 
     Sample(
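
The comment in `transformSample` above is the heart of the fix: without a
stable order (or a cache), two materializations of the same input can assign
the per-row random draws to different rows, so splits that should partition
the data can overlap. An illustrative, Spark-free sketch of the hazard
(hypothetical toy code, not from this patch):

    import random

    rows = ["a", "b", "c", "d", "e", "f"]

    def sample_range(rows, lo, hi, seed):
        # One random draw per row, in row order; keep rows whose draw
        # falls in [lo, hi), mirroring how Sample bounds are applied.
        rng = random.Random(seed)
        return [r for r in rows if lo <= rng.random() < hi]

    split1 = sample_range(rows, 0.0, 0.5, seed=42)
    reordered = rows[::-1]  # same data, different scan order
    split2 = sample_range(reordered, 0.5, 1.0, seed=42)
    # With a stable row order these splits would be disjoint; after the
    # reorder the intersection is non-empty: a row landed in both splits.
    print(sorted(set(split1) & set(split2)))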
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 256e63122ab..018785b77b0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -523,7 +523,7 @@ class DataFrame:
                     upper_bound=upperBound,
                     with_replacement=False,
                     seed=int(seed),
-                    force_stable_sort=True,
+                    deterministic_order=True,
                 ),
                 session=self._session,
             )
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index e1b9fa0d0e4..069a1329e80 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -558,14 +558,14 @@ class Sample(LogicalPlan):
         upper_bound: float,
         with_replacement: bool,
         seed: Optional[int],
-        force_stable_sort: bool = False,
+        deterministic_order: bool = False,
     ) -> None:
         super().__init__(child)
         self.lower_bound = lower_bound
         self.upper_bound = upper_bound
         self.with_replacement = with_replacement
         self.seed = seed
-        self.force_stable_sort = force_stable_sort
+        self.deterministic_order = deterministic_order
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
@@ -576,7 +576,7 @@ class Sample(LogicalPlan):
         plan.sample.with_replacement = self.with_replacement
         if self.seed is not None:
             plan.sample.seed = self.seed
-        plan.sample.force_stable_sort = self.force_stable_sort
+        plan.sample.deterministic_order = self.deterministic_order
         return plan
 
 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 92a6d252a68..205caca07b1 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xa4\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xa4\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -612,47 +612,47 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _LOCALRELATION._serialized_start = 5266
     _LOCALRELATION._serialized_end = 5403
     _SAMPLE._serialized_start = 5406
-    _SAMPLE._serialized_end = 5701
-    _RANGE._serialized_start = 5704
-    _RANGE._serialized_end = 5849
-    _SUBQUERYALIAS._serialized_start = 5851
-    _SUBQUERYALIAS._serialized_end = 5965
-    _REPARTITION._serialized_start = 5968
-    _REPARTITION._serialized_end = 6110
-    _SHOWSTRING._serialized_start = 6113
-    _SHOWSTRING._serialized_end = 6255
-    _STATSUMMARY._serialized_start = 6257
-    _STATSUMMARY._serialized_end = 6349
-    _STATDESCRIBE._serialized_start = 6351
-    _STATDESCRIBE._serialized_end = 6432
-    _STATCROSSTAB._serialized_start = 6434
-    _STATCROSSTAB._serialized_end = 6535
-    _STATCOV._serialized_start = 6537
-    _STATCOV._serialized_end = 6633
-    _STATCORR._serialized_start = 6636
-    _STATCORR._serialized_end = 6773
-    _NAFILL._serialized_start = 6776
-    _NAFILL._serialized_end = 6910
-    _NADROP._serialized_start = 6913
-    _NADROP._serialized_end = 7047
-    _NAREPLACE._serialized_start = 7050
-    _NAREPLACE._serialized_end = 7346
-    _NAREPLACE_REPLACEMENT._serialized_start = 7205
-    _NAREPLACE_REPLACEMENT._serialized_end = 7346
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7348
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7462
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7465
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7724
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7657
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7724
-    _WITHCOLUMNS._serialized_start = 7727
-    _WITHCOLUMNS._serialized_end = 7858
-    _HINT._serialized_start = 7861
-    _HINT._serialized_end = 8001
-    _UNPIVOT._serialized_start = 8004
-    _UNPIVOT._serialized_end = 8250
-    _TOSCHEMA._serialized_start = 8252
-    _TOSCHEMA._serialized_end = 8358
-    _REPARTITIONBYEXPRESSION._serialized_start = 8361
-    _REPARTITIONBYEXPRESSION._serialized_end = 8564
+    _SAMPLE._serialized_end = 5679
+    _RANGE._serialized_start = 5682
+    _RANGE._serialized_end = 5827
+    _SUBQUERYALIAS._serialized_start = 5829
+    _SUBQUERYALIAS._serialized_end = 5943
+    _REPARTITION._serialized_start = 5946
+    _REPARTITION._serialized_end = 6088
+    _SHOWSTRING._serialized_start = 6091
+    _SHOWSTRING._serialized_end = 6233
+    _STATSUMMARY._serialized_start = 6235
+    _STATSUMMARY._serialized_end = 6327
+    _STATDESCRIBE._serialized_start = 6329
+    _STATDESCRIBE._serialized_end = 6410
+    _STATCROSSTAB._serialized_start = 6412
+    _STATCROSSTAB._serialized_end = 6513
+    _STATCOV._serialized_start = 6515
+    _STATCOV._serialized_end = 6611
+    _STATCORR._serialized_start = 6614
+    _STATCORR._serialized_end = 6751
+    _NAFILL._serialized_start = 6754
+    _NAFILL._serialized_end = 6888
+    _NADROP._serialized_start = 6891
+    _NADROP._serialized_end = 7025
+    _NAREPLACE._serialized_start = 7028
+    _NAREPLACE._serialized_end = 7324
+    _NAREPLACE_REPLACEMENT._serialized_start = 7183
+    _NAREPLACE_REPLACEMENT._serialized_end = 7324
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7326
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7440
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7443
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7702
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7635
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7702
+    _WITHCOLUMNS._serialized_start = 7705
+    _WITHCOLUMNS._serialized_end = 7836
+    _HINT._serialized_start = 7839
+    _HINT._serialized_end = 7979
+    _UNPIVOT._serialized_start = 7982
+    _UNPIVOT._serialized_end = 8228
+    _TOSCHEMA._serialized_start = 8230
+    _TOSCHEMA._serialized_end = 8336
+    _REPARTITIONBYEXPRESSION._serialized_start = 8339
+    _REPARTITIONBYEXPRESSION._serialized_end = 8542
 # @@protoc_insertion_point(module_scope)
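
The wall of regenerated offsets above is mechanical: with the synthetic oneof
for the old optional field gone (and the field renamed), the serialized
Sample descriptor comes out 22 bytes shorter, so every start/end offset after
_SAMPLE shifts down by the same delta. A quick sanity check on the numbers
visible in this hunk (plain arithmetic, nothing more):

    delta = 5701 - 5679            # old vs. new _SAMPLE._serialized_end
    assert delta == 22
    assert 8564 - 8542 == delta    # the final offset shifts identically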
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 36b55ec4f6b..e83e64f64e8 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1293,7 +1293,7 @@ class Sample(google.protobuf.message.Message):
     UPPER_BOUND_FIELD_NUMBER: builtins.int
     WITH_REPLACEMENT_FIELD_NUMBER: builtins.int
     SEED_FIELD_NUMBER: builtins.int
-    FORCE_STABLE_SORT_FIELD_NUMBER: builtins.int
+    DETERMINISTIC_ORDER_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> global___Relation:
         """(Required) Input relation for a Sample."""
@@ -1305,9 +1305,10 @@ class Sample(google.protobuf.message.Message):
     """(Optional) Whether to sample with replacement."""
     seed: builtins.int
     """(Optional) The random seed."""
-    force_stable_sort: builtins.bool
-    """(Optional) Explicitly sort the underlying plan to make the ordering deterministic.
-    This flag is only used to randomly splits DataFrame with the provided weights.
+    deterministic_order: builtins.bool
+    """(Required) Explicitly sort the underlying plan to make the ordering deterministic or cache it.
+    This flag is true when invoking `dataframe.randomSplit` to randomly split the DataFrame with
+    the provided weights. Otherwise, it is false.
     """
     def __init__(
         self,
@@ -1317,19 +1318,15 @@ class Sample(google.protobuf.message.Message):
         upper_bound: builtins.float = ...,
         with_replacement: builtins.bool | None = ...,
         seed: builtins.int | None = ...,
-        force_stable_sort: builtins.bool | None = ...,
+        deterministic_order: builtins.bool = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "_force_stable_sort",
-            b"_force_stable_sort",
             "_seed",
             b"_seed",
             "_with_replacement",
             b"_with_replacement",
-            "force_stable_sort",
-            b"force_stable_sort",
             "input",
             b"input",
             "seed",
@@ -1341,14 +1338,12 @@ class Sample(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "_force_stable_sort",
-            b"_force_stable_sort",
             "_seed",
             b"_seed",
             "_with_replacement",
             b"_with_replacement",
-            "force_stable_sort",
-            b"force_stable_sort",
+            "deterministic_order",
+            b"deterministic_order",
             "input",
             b"input",
             "lower_bound",
@@ -1362,10 +1357,6 @@ class Sample(google.protobuf.message.Message):
         ],
     ) -> None: ...
     @typing.overload
-    def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["_force_stable_sort", b"_force_stable_sort"]
-    ) -> typing_extensions.Literal["force_stable_sort"] | None: ...
-    @typing.overload
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
     ) -> typing_extensions.Literal["seed"] | None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 240fd7d4d72..529e3ca3eda 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -251,21 +251,21 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
             self.assertEqual(plan.root.sample.upper_bound, 0.16666666666666666)
             self.assertEqual(plan.root.sample.with_replacement, False)
             self.assertEqual(plan.root.sample.HasField("seed"), True)
-            self.assertEqual(plan.root.sample.force_stable_sort, True)
+            self.assertEqual(plan.root.sample.deterministic_order, True)
 
             plan = relations[1]._plan.to_proto(self.connect)
             self.assertEqual(plan.root.sample.lower_bound, 0.16666666666666666)
             self.assertEqual(plan.root.sample.upper_bound, 0.5)
             self.assertEqual(plan.root.sample.with_replacement, False)
             self.assertEqual(plan.root.sample.HasField("seed"), True)
-            self.assertEqual(plan.root.sample.force_stable_sort, True)
+            self.assertEqual(plan.root.sample.deterministic_order, True)
 
             plan = relations[2]._plan.to_proto(self.connect)
             self.assertEqual(plan.root.sample.lower_bound, 0.5)
             self.assertEqual(plan.root.sample.upper_bound, 1.0)
             self.assertEqual(plan.root.sample.with_replacement, False)
             self.assertEqual(plan.root.sample.HasField("seed"), True)
-            self.assertEqual(plan.root.sample.force_stable_sort, True)
+            self.assertEqual(plan.root.sample.deterministic_order, True)
 
         relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0], 1)
         checkRelations(relations)
@@ -326,7 +326,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.upper_bound, 0.3)
         self.assertEqual(plan.root.sample.with_replacement, False)
         self.assertEqual(plan.root.sample.HasField("seed"), False)
-        self.assertEqual(plan.root.sample.force_stable_sort, False)
+        self.assertEqual(plan.root.sample.deterministic_order, False)
 
         plan = (
             df.filter(df.col_name > 3)
@@ -337,7 +337,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.upper_bound, 0.4)
         self.assertEqual(plan.root.sample.with_replacement, True)
         self.assertEqual(plan.root.sample.seed, -1)
-        self.assertEqual(plan.root.sample.force_stable_sort, False)
+        self.assertEqual(plan.root.sample.deterministic_order, False)
 
     def test_sort(self):
         df = self.connect.readTable(table_name=self.tbl_name)
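
Taken together, the user-visible behavior (sketch assuming a live Spark
Connect session `spark` and a hypothetical table name; not a test from this
patch):

    df = spark.read.table("some_table")
    # Each split is planned as a Sample with deterministic_order=True, so the
    # server sorts the input once (or caches it when no column is orderable).
    splits = df.randomSplit([1.0, 2.0, 3.0], 1)
    # A general sample keeps deterministic_order=False and is planned directly,
    # with no extra sort and no cache, which is the point of this follow-up.
    sampled = df.sample(fraction=0.3)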

