This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7934f00d124 [SPARK-40839][CONNECT][PYTHON] Implement `DataFrame.sample` 7934f00d124 is described below commit 7934f00d1241431dd59207650693aaad1a319a70 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Fri Oct 21 17:18:34 2022 +0900 [SPARK-40839][CONNECT][PYTHON] Implement `DataFrame.sample` ### What changes were proposed in this pull request? Implement `DataFrame.sample` in Connect ### Why are the changes needed? To extend DataFrame API coverage in Spark Connect. ### Does this PR introduce _any_ user-facing change? Yes, new API ``` def sample( self, fraction: float, *, withReplacement: bool = False, seed: Optional[int] = None, ) -> "DataFrame": ``` ### How was this patch tested? Added unit tests. Closes #38310 from zhengruifeng/connect_df_sample. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../main/protobuf/spark/connect/relations.proto | 6 ++- .../org/apache/spark/sql/connect/dsl/package.scala | 3 +- .../sql/connect/planner/SparkConnectPlanner.scala | 5 ++- python/pyspark/sql/connect/dataframe.py | 27 ++++++++++++ python/pyspark/sql/connect/plan.py | 50 ++++++++++++++++++++++ python/pyspark/sql/connect/proto/relations_pb2.py | 6 ++- python/pyspark/sql/connect/proto/relations_pb2.pyi | 19 ++++++-- .../sql/tests/connect/test_connect_plan_only.py | 18 ++++++++ 8 files changed, 125 insertions(+), 9 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index 6adf0831ea2..7dbde775ee8 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -201,5 +201,9 @@ message Sample { double lower_bound = 2; double upper_bound = 3; bool with_replacement = 4; - int64 seed = 5; + Seed seed = 5; + + message Seed { + int64 seed = 1; + } } diff --git 
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 68bbc0487f9..4630c86049c 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -272,7 +272,8 @@ package object dsl { .setUpperBound(upperBound) .setLowerBound(lowerBound) .setWithReplacement(withReplacement) - .setSeed(seed)) + .setSeed(proto.Sample.Seed.newBuilder().setSeed(seed).build()) + .build()) .build() } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 92c8bf01cba..880618cc333 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sa import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils final case class InvalidPlanInput( private val message: String = "", @@ -80,7 +81,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { /** * All fields of [[proto.Sample]] are optional. However, given those are proto primitive types, - * we cannot differentiate if the fied is not or set when the field's value equals to the type + * we cannot differentiate if the field is not or set when the field's value equals to the type * default value. In the future if this ever become a problem, one solution could be that to * wrap such fields into proto messages. 
*/ @@ -89,7 +90,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { rel.getLowerBound, rel.getUpperBound, rel.getWithReplacement, - rel.getSeed, + if (rel.hasSeed) rel.getSeed.getSeed else Utils.random.nextLong, transformRelation(rel.getInput)) } diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 5ca747fdd6a..eabcf433ae9 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -206,6 +206,33 @@ class DataFrame(object): """Sort by a specific column""" return DataFrame.withPlan(plan.Sort(self._plan, *cols), session=self._session) + def sample( + self, + fraction: float, + *, + withReplacement: bool = False, + seed: Optional[int] = None, + ) -> "DataFrame": + if not isinstance(fraction, float): + raise TypeError(f"'fraction' must be float, but got {type(fraction).__name__}") + if not isinstance(withReplacement, bool): + raise TypeError( + f"'withReplacement' must be bool, but got {type(withReplacement).__name__}" + ) + if seed is not None and not isinstance(seed, int): + raise TypeError(f"'seed' must be None or int, but got {type(seed).__name__}") + + return DataFrame.withPlan( + plan.Sample( + child=self._plan, + lower_bound=0.0, + upper_bound=fraction, + with_replacement=withReplacement, + seed=seed, + ), + session=self._session, + ) + def show(self, n: int, truncate: Optional[Union[bool, int]], vertical: Optional[bool]) -> None: ... 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 5b8b7c71866..297b15994d3 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -385,6 +385,56 @@ class Sort(LogicalPlan): """ +class Sample(LogicalPlan): + def __init__( + self, + child: Optional["LogicalPlan"], + lower_bound: float, + upper_bound: float, + with_replacement: bool, + seed: Optional[int], + ) -> None: + super().__init__(child) + self.lower_bound = lower_bound + self.upper_bound = upper_bound + self.with_replacement = with_replacement + self.seed = seed + + def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: + assert self._child is not None + plan = proto.Relation() + plan.sample.input.CopyFrom(self._child.plan(session)) + plan.sample.lower_bound = self.lower_bound + plan.sample.upper_bound = self.upper_bound + plan.sample.with_replacement = self.with_replacement + if self.seed is not None: + plan.sample.seed.seed = self.seed + return plan + + def print(self, indent: int = 0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return ( + f"{' ' * indent}" + f"<Sample lowerBound={self.lower_bound}, upperBound={self.upper_bound}, " + f"withReplacement={self.with_replacement}, seed={self.seed}>" + f"\n{c_buf}" + ) + + def _repr_html_(self) -> str: + return f""" + <ul> + <li> + <b>Sample</b><br /> + LowerBound: {self.lower_bound} <br /> + UpperBound: {self.upper_bound} <br /> + WithReplacement: {self.with_replacement} <br /> + Seed: {self.seed} <br /> + {self._child_repr_()} + </li> + </uL> + """ + + class Aggregate(LogicalPlan): MeasureType = Tuple["ExpressionOrString", str] MeasuresType = Sequence[MeasureType] diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 1c868bcf411..d9a596fba8c 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ 
b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -92,5 +92,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: _LOCALRELATION._serialized_start = 3387 _LOCALRELATION._serialized_end = 3480 _SAMPLE._serialized_start = 3483 - _SAMPLE._serialized_end = 3667 + _SAMPLE._serialized_end = 3723 + _SAMPLE_SEED._serialized_start = 3697 + _SAMPLE_SEED._serialized_end = 3723 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index fc135c559a6..df179df1480 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -839,6 +839,18 @@ class Sample(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class Seed(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SEED_FIELD_NUMBER: builtins.int + seed: builtins.int + def __init__( + self, + *, + seed: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["seed", b"seed"]) -> None: ... + INPUT_FIELD_NUMBER: builtins.int LOWER_BOUND_FIELD_NUMBER: builtins.int UPPER_BOUND_FIELD_NUMBER: builtins.int @@ -849,7 +861,8 @@ class Sample(google.protobuf.message.Message): lower_bound: builtins.float upper_bound: builtins.float with_replacement: builtins.bool - seed: builtins.int + @property + def seed(self) -> global___Sample.Seed: ... def __init__( self, *, @@ -857,10 +870,10 @@ class Sample(google.protobuf.message.Message): lower_bound: builtins.float = ..., upper_bound: builtins.float = ..., with_replacement: builtins.bool = ..., - seed: builtins.int = ..., + seed: global___Sample.Seed | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["input", b"input"] + self, field_name: typing_extensions.Literal["input", b"input", "seed", b"seed"] ) -> builtins.bool: ... 
def ClearField( self, diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 739c24ca96e..3b609db7a02 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -54,6 +54,24 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): offset_plan = df.offset(10)._plan.to_proto(self.connect) self.assertEqual(offset_plan.root.offset.offset, 10) + def test_sample(self): + df = self.connect.readTable(table_name=self.tbl_name) + plan = df.filter(df.col_name > 3).sample(fraction=0.3)._plan.to_proto(self.connect) + self.assertEqual(plan.root.sample.lower_bound, 0.0) + self.assertEqual(plan.root.sample.upper_bound, 0.3) + self.assertEqual(plan.root.sample.with_replacement, False) + self.assertEqual(plan.root.sample.HasField("seed"), False) + + plan = ( + df.filter(df.col_name > 3) + .sample(withReplacement=True, fraction=0.4, seed=-1) + ._plan.to_proto(self.connect) + ) + self.assertEqual(plan.root.sample.lower_bound, 0.0) + self.assertEqual(plan.root.sample.upper_bound, 0.4) + self.assertEqual(plan.root.sample.with_replacement, True) + self.assertEqual(plan.root.sample.seed.seed, -1) + def test_relation_alias(self): df = self.connect.readTable(table_name=self.tbl_name) plan = df.alias("table_alias")._plan.to_proto(self.connect) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org