This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7934f00d124 [SPARK-40839][CONNECT][PYTHON] Implement `DataFrame.sample`
7934f00d124 is described below

commit 7934f00d1241431dd59207650693aaad1a319a70
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Oct 21 17:18:34 2022 +0900

    [SPARK-40839][CONNECT][PYTHON] Implement `DataFrame.sample`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.sample` in Connect
    
    ### Why are the changes needed?
    For DataFrame API coverage in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds a new API:
    
    ```
        def sample(
            self,
            fraction: float,
            *,
            withReplacement: bool = False,
            seed: Optional[int] = None,
        ) -> "DataFrame":
    ```
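
    For example, a minimal usage sketch (hypothetical; `df` stands for any Connect `DataFrame`):

    ```
    df.sample(0.3)                                 # ~30% of rows; the server picks a random seed
    df.sample(0.3, withReplacement=True, seed=42)  # reproducible sample with replacement
    df.sample(0.3, True)                           # TypeError: withReplacement is keyword-only
    ```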
    
    ### How was this patch tested?
    Added a unit test (`test_sample`).
    
    Closes #38310 from zhengruifeng/connect_df_sample.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  6 ++-
 .../org/apache/spark/sql/connect/dsl/package.scala |  3 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  5 ++-
 python/pyspark/sql/connect/dataframe.py            | 27 ++++++++++++
 python/pyspark/sql/connect/plan.py                 | 50 ++++++++++++++++++++++
 python/pyspark/sql/connect/proto/relations_pb2.py  |  6 ++-
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 19 ++++++--
 .../sql/tests/connect/test_connect_plan_only.py    | 18 ++++++++
 8 files changed, 125 insertions(+), 9 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 6adf0831ea2..7dbde775ee8 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -201,5 +201,9 @@ message Sample {
   double lower_bound = 2;
   double upper_bound = 3;
   bool with_replacement = 4;
-  int64 seed = 5;
+  Seed seed = 5;
+
+  message Seed {
+    int64 seed = 1;
+  }
 }
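
Wrapping the scalar `seed` in a nested message is the usual proto3 idiom for field presence: a message-typed field can be tested with `HasField`, whereas a plain `int64` cannot be distinguished from its default value `0`. A minimal sketch against the generated Python bindings (illustrative, not part of the commit):

```
from pyspark.sql.connect.proto import relations_pb2

sample = relations_pb2.Sample()
print(sample.HasField("seed"))  # False: the Seed wrapper message is unset

sample.seed.seed = 0            # setting it, even to the scalar default...
print(sample.HasField("seed"))  # True: ...is now observable on the server
```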
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 68bbc0487f9..4630c86049c 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -272,7 +272,8 @@ package object dsl {
               .setUpperBound(upperBound)
               .setLowerBound(lowerBound)
               .setWithReplacement(withReplacement)
-              .setSeed(seed))
+              .setSeed(proto.Sample.Seed.newBuilder().setSeed(seed).build())
+              .build())
           .build()
       }
 
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 92c8bf01cba..880618cc333 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sa
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 final case class InvalidPlanInput(
     private val message: String = "",
@@ -80,7 +81,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
 
   /**
    * All fields of [[proto.Sample]] are optional. However, given those are proto primitive types,
-   * we cannot differentiate if the fied is not or set when the field's value equals to the type
+   * we cannot differentiate if the field is not or set when the field's value equals to the type
    * default value. In the future if this ever become a problem, one solution could be that to
    * wrap such fields into proto messages.
    */
@@ -89,7 +90,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       rel.getLowerBound,
       rel.getUpperBound,
       rel.getWithReplacement,
-      rel.getSeed,
+      if (rel.hasSeed) rel.getSeed.getSeed else Utils.random.nextLong,
       transformRelation(rel.getInput))
   }
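
The planner reads the seed only when the `Seed` wrapper is present and otherwise draws a random one. A rough Python sketch of that fallback (illustrative only; the authoritative logic is the Scala above):

```
import random

def resolve_seed(sample_rel) -> int:
    # Use the client-supplied seed if the wrapper message is set;
    # otherwise fall back to a random 64-bit value, mirroring Utils.random.nextLong.
    if sample_rel.HasField("seed"):
        return sample_rel.seed.seed
    return random.randint(-(2**63), 2**63 - 1)
```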
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 5ca747fdd6a..eabcf433ae9 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -206,6 +206,33 @@ class DataFrame(object):
         """Sort by a specific column"""
         return DataFrame.withPlan(plan.Sort(self._plan, *cols), session=self._session)
 
+    def sample(
+        self,
+        fraction: float,
+        *,
+        withReplacement: bool = False,
+        seed: Optional[int] = None,
+    ) -> "DataFrame":
+        if not isinstance(fraction, float):
+            raise TypeError(f"'fraction' must be float, but got 
{type(fraction).__name__}")
+        if not isinstance(withReplacement, bool):
+            raise TypeError(
+                f"'withReplacement' must be bool, but got {type(withReplacement).__name__}"
+            )
+        if seed is not None and not isinstance(seed, int):
+            raise TypeError(f"'seed' must be None or int, but got 
{type(seed).__name__}")
+
+        return DataFrame.withPlan(
+            plan.Sample(
+                child=self._plan,
+                lower_bound=0.0,
+                upper_bound=fraction,
+                with_replacement=withReplacement,
+                seed=seed,
+            ),
+            session=self._session,
+        )
+
     def show(self, n: int, truncate: Optional[Union[bool, int]], vertical: Optional[bool]) -> None:
         ...
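
The eager type checks above fail fast on the client, before any plan is sent to the server. A short sketch (hypothetical; `df` is any Connect `DataFrame`):

```
df.sample(1)              # TypeError: 'fraction' must be float, but got int
df.sample(0.5, seed=1.5)  # TypeError: 'seed' must be None or int, but got float
df.sample(0.5, seed=10)   # OK: returns a new DataFrame wrapping a Sample plan node
```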
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 5b8b7c71866..297b15994d3 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -385,6 +385,56 @@ class Sort(LogicalPlan):
         """
 
 
+class Sample(LogicalPlan):
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        lower_bound: float,
+        upper_bound: float,
+        with_replacement: bool,
+        seed: Optional[int],
+    ) -> None:
+        super().__init__(child)
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.with_replacement = with_replacement
+        self.seed = seed
+
+    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+        assert self._child is not None
+        plan = proto.Relation()
+        plan.sample.input.CopyFrom(self._child.plan(session))
+        plan.sample.lower_bound = self.lower_bound
+        plan.sample.upper_bound = self.upper_bound
+        plan.sample.with_replacement = self.with_replacement
+        if self.seed is not None:
+            plan.sample.seed.seed = self.seed
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
+        return (
+            f"{' ' * indent}"
+            f"<Sample lowerBound={self.lower_bound}, 
upperBound={self.upper_bound}, "
+            f"withReplacement={self.with_replacement}, seed={self.seed}>"
+            f"\n{c_buf}"
+        )
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+            <li>
+                <b>Sample</b><br />
+                LowerBound: {self.lower_bound} <br />
+                UpperBound: {self.upper_bound} <br />
+                WithReplacement: {self.with_replacement} <br />
+                Seed: {self.seed} <br />
+                {self._child_repr_()}
+            </li>
+        </ul>
+        """
+
+
 class Aggregate(LogicalPlan):
     MeasureType = Tuple["ExpressionOrString", str]
     MeasuresType = Sequence[MeasureType]
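
As `Sample.plan()` above shows, the `Seed` wrapper is populated only when a seed was supplied. A direct-construction sketch (hypothetical `child` plan and `session` objects):

```
from pyspark.sql.connect import plan as p

node = p.Sample(
    child=child,          # any LogicalPlan, e.g. another DataFrame's _plan
    lower_bound=0.0,
    upper_bound=0.3,
    with_replacement=False,
    seed=None,
)
rel = node.plan(session)  # session: a RemoteSparkSession, or a test fixture
assert not rel.sample.HasField("seed")  # no Seed message when seed is None
```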
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 1c868bcf411..d9a596fba8c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -92,5 +92,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _LOCALRELATION._serialized_start = 3387
     _LOCALRELATION._serialized_end = 3480
     _SAMPLE._serialized_start = 3483
-    _SAMPLE._serialized_end = 3667
+    _SAMPLE._serialized_end = 3723
+    _SAMPLE_SEED._serialized_start = 3697
+    _SAMPLE_SEED._serialized_end = 3723
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index fc135c559a6..df179df1480 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -839,6 +839,18 @@ class Sample(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    class Seed(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        SEED_FIELD_NUMBER: builtins.int
+        seed: builtins.int
+        def __init__(
+            self,
+            *,
+            seed: builtins.int = ...,
+        ) -> None: ...
+        def ClearField(self, field_name: typing_extensions.Literal["seed", b"seed"]) -> None: ...
+
     INPUT_FIELD_NUMBER: builtins.int
     LOWER_BOUND_FIELD_NUMBER: builtins.int
     UPPER_BOUND_FIELD_NUMBER: builtins.int
@@ -849,7 +861,8 @@ class Sample(google.protobuf.message.Message):
     lower_bound: builtins.float
     upper_bound: builtins.float
     with_replacement: builtins.bool
-    seed: builtins.int
+    @property
+    def seed(self) -> global___Sample.Seed: ...
     def __init__(
         self,
         *,
@@ -857,10 +870,10 @@ class Sample(google.protobuf.message.Message):
         lower_bound: builtins.float = ...,
         upper_bound: builtins.float = ...,
         with_replacement: builtins.bool = ...,
-        seed: builtins.int = ...,
+        seed: global___Sample.Seed | None = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["input", b"input"]
+        self, field_name: typing_extensions.Literal["input", b"input", "seed", b"seed"]
     ) -> builtins.bool: ...
     def ClearField(
         self,
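
With the updated stubs, `seed` is a message-typed field, so typed construction of a `Sample` relation looks like this (illustrative sketch, not part of the commit):

```
from pyspark.sql.connect.proto import relations_pb2 as proto

s = proto.Sample(
    lower_bound=0.0,
    upper_bound=0.3,
    with_replacement=False,
    seed=proto.Sample.Seed(seed=42),  # omit to leave the seed unset
)
assert s.HasField("seed") and s.seed.seed == 42
```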
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 739c24ca96e..3b609db7a02 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -54,6 +54,24 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         offset_plan = df.offset(10)._plan.to_proto(self.connect)
         self.assertEqual(offset_plan.root.offset.offset, 10)
 
+    def test_sample(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+        plan = df.filter(df.col_name > 3).sample(fraction=0.3)._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.sample.lower_bound, 0.0)
+        self.assertEqual(plan.root.sample.upper_bound, 0.3)
+        self.assertEqual(plan.root.sample.with_replacement, False)
+        self.assertEqual(plan.root.sample.HasField("seed"), False)
+
+        plan = (
+            df.filter(df.col_name > 3)
+            .sample(withReplacement=True, fraction=0.4, seed=-1)
+            ._plan.to_proto(self.connect)
+        )
+        self.assertEqual(plan.root.sample.lower_bound, 0.0)
+        self.assertEqual(plan.root.sample.upper_bound, 0.4)
+        self.assertEqual(plan.root.sample.with_replacement, True)
+        self.assertEqual(plan.root.sample.seed.seed, -1)
+
     def test_relation_alias(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.alias("table_alias")._plan.to_proto(self.connect)


