This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7934f00d124 [SPARK-40839][CONNECT][PYTHON] Implement `DataFrame.sample`
7934f00d124 is described below
commit 7934f00d1241431dd59207650693aaad1a319a70
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Oct 21 17:18:34 2022 +0900
[SPARK-40839][CONNECT][PYTHON] Implement `DataFrame.sample`
### What changes were proposed in this pull request?
Implement `DataFrame.sample` in Connect
### Why are the changes needed?
To improve DataFrame API coverage in Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes, it adds a new API:
```
def sample(
self,
fraction: float,
*,
withReplacement: bool = False,
seed: Optional[int] = None,
) -> "DataFrame":
```
### How was this patch tested?
Added a unit test (`test_sample` in `test_connect_plan_only.py`).
Closes #38310 from zhengruifeng/connect_df_sample.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 6 ++-
.../org/apache/spark/sql/connect/dsl/package.scala | 3 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 5 ++-
python/pyspark/sql/connect/dataframe.py | 27 ++++++++++++
python/pyspark/sql/connect/plan.py | 50 ++++++++++++++++++++++
python/pyspark/sql/connect/proto/relations_pb2.py | 6 ++-
python/pyspark/sql/connect/proto/relations_pb2.pyi | 19 ++++++--
.../sql/tests/connect/test_connect_plan_only.py | 18 ++++++++
8 files changed, 125 insertions(+), 9 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto
b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 6adf0831ea2..7dbde775ee8 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -201,5 +201,9 @@ message Sample {
double lower_bound = 2;
double upper_bound = 3;
bool with_replacement = 4;
- int64 seed = 5;
+ Seed seed = 5;
+
+ message Seed {
+ int64 seed = 1;
+ }
}
diff --git
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 68bbc0487f9..4630c86049c 100644
---
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -272,7 +272,8 @@ package object dsl {
.setUpperBound(upperBound)
.setLowerBound(lowerBound)
.setWithReplacement(withReplacement)
- .setSeed(seed))
+ .setSeed(proto.Sample.Seed.newBuilder().setSeed(seed).build())
+ .build())
.build()
}
diff --git
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 92c8bf01cba..880618cc333 100644
---
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -31,6 +31,7 @@ import
org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sa
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
final case class InvalidPlanInput(
private val message: String = "",
@@ -80,7 +81,7 @@ class SparkConnectPlanner(plan: proto.Relation, session:
SparkSession) {
/**
* All fields of [[proto.Sample]] are optional. However, given those are
proto primitive types,
- * we cannot differentiate if the fied is not or set when the field's value
equals to the type
+ * we cannot differentiate if the field is not or set when the field's value
equals to the type
* default value. In the future if this ever become a problem, one solution
could be that to
* wrap such fields into proto messages.
*/
@@ -89,7 +90,7 @@ class SparkConnectPlanner(plan: proto.Relation, session:
SparkSession) {
rel.getLowerBound,
rel.getUpperBound,
rel.getWithReplacement,
- rel.getSeed,
+ if (rel.hasSeed) rel.getSeed.getSeed else Utils.random.nextLong,
transformRelation(rel.getInput))
}
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index 5ca747fdd6a..eabcf433ae9 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -206,6 +206,33 @@ class DataFrame(object):
"""Sort by a specific column"""
return DataFrame.withPlan(plan.Sort(self._plan, *cols),
session=self._session)
+ def sample(
+ self,
+ fraction: float,
+ *,
+ withReplacement: bool = False,
+ seed: Optional[int] = None,
+ ) -> "DataFrame":
+ if not isinstance(fraction, float):
+ raise TypeError(f"'fraction' must be float, but got
{type(fraction).__name__}")
+ if not isinstance(withReplacement, bool):
+ raise TypeError(
+ f"'withReplacement' must be bool, but got
{type(withReplacement).__name__}"
+ )
+ if seed is not None and not isinstance(seed, int):
+ raise TypeError(f"'seed' must be None or int, but got
{type(seed).__name__}")
+
+ return DataFrame.withPlan(
+ plan.Sample(
+ child=self._plan,
+ lower_bound=0.0,
+ upper_bound=fraction,
+ with_replacement=withReplacement,
+ seed=seed,
+ ),
+ session=self._session,
+ )
+
def show(self, n: int, truncate: Optional[Union[bool, int]], vertical:
Optional[bool]) -> None:
...
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index 5b8b7c71866..297b15994d3 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -385,6 +385,56 @@ class Sort(LogicalPlan):
"""
+class Sample(LogicalPlan):
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ lower_bound: float,
+ upper_bound: float,
+ with_replacement: bool,
+ seed: Optional[int],
+ ) -> None:
+ super().__init__(child)
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+ self.with_replacement = with_replacement
+ self.seed = seed
+
+ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+ assert self._child is not None
+ plan = proto.Relation()
+ plan.sample.input.CopyFrom(self._child.plan(session))
+ plan.sample.lower_bound = self.lower_bound
+ plan.sample.upper_bound = self.upper_bound
+ plan.sample.with_replacement = self.with_replacement
+ if self.seed is not None:
+ plan.sample.seed.seed = self.seed
+ return plan
+
+ def print(self, indent: int = 0) -> str:
+ c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child
else ""
+ return (
+ f"{' ' * indent}"
+ f"<Sample lowerBound={self.lower_bound},
upperBound={self.upper_bound}, "
+ f"withReplacement={self.with_replacement}, seed={self.seed}>"
+ f"\n{c_buf}"
+ )
+
+ def _repr_html_(self) -> str:
+ return f"""
+ <ul>
+ <li>
+ <b>Sample</b><br />
+ LowerBound: {self.lower_bound} <br />
+ UpperBound: {self.upper_bound} <br />
+ WithReplacement: {self.with_replacement} <br />
+ Seed: {self.seed} <br />
+ {self._child_repr_()}
+ </li>
+ </uL>
+ """
+
+
class Aggregate(LogicalPlan):
MeasureType = Tuple["ExpressionOrString", str]
MeasuresType = Sequence[MeasureType]
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py
b/python/pyspark/sql/connect/proto/relations_pb2.py
index 1c868bcf411..d9a596fba8c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as
spark_dot_connect_dot_e
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04
\x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05
\x01(\x0 [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04
\x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05
\x01(\x0 [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -92,5 +92,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_LOCALRELATION._serialized_start = 3387
_LOCALRELATION._serialized_end = 3480
_SAMPLE._serialized_start = 3483
- _SAMPLE._serialized_end = 3667
+ _SAMPLE._serialized_end = 3723
+ _SAMPLE_SEED._serialized_start = 3697
+ _SAMPLE_SEED._serialized_end = 3723
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index fc135c559a6..df179df1480 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -839,6 +839,18 @@ class Sample(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ class Seed(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SEED_FIELD_NUMBER: builtins.int
+ seed: builtins.int
+ def __init__(
+ self,
+ *,
+ seed: builtins.int = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal["seed",
b"seed"]) -> None: ...
+
INPUT_FIELD_NUMBER: builtins.int
LOWER_BOUND_FIELD_NUMBER: builtins.int
UPPER_BOUND_FIELD_NUMBER: builtins.int
@@ -849,7 +861,8 @@ class Sample(google.protobuf.message.Message):
lower_bound: builtins.float
upper_bound: builtins.float
with_replacement: builtins.bool
- seed: builtins.int
+ @property
+ def seed(self) -> global___Sample.Seed: ...
def __init__(
self,
*,
@@ -857,10 +870,10 @@ class Sample(google.protobuf.message.Message):
lower_bound: builtins.float = ...,
upper_bound: builtins.float = ...,
with_replacement: builtins.bool = ...,
- seed: builtins.int = ...,
+ seed: global___Sample.Seed | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["input", b"input"]
+ self, field_name: typing_extensions.Literal["input", b"input", "seed",
b"seed"]
) -> builtins.bool: ...
def ClearField(
self,
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 739c24ca96e..3b609db7a02 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -54,6 +54,24 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
offset_plan = df.offset(10)._plan.to_proto(self.connect)
self.assertEqual(offset_plan.root.offset.offset, 10)
+ def test_sample(self):
+ df = self.connect.readTable(table_name=self.tbl_name)
+ plan = df.filter(df.col_name >
3).sample(fraction=0.3)._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.sample.lower_bound, 0.0)
+ self.assertEqual(plan.root.sample.upper_bound, 0.3)
+ self.assertEqual(plan.root.sample.with_replacement, False)
+ self.assertEqual(plan.root.sample.HasField("seed"), False)
+
+ plan = (
+ df.filter(df.col_name > 3)
+ .sample(withReplacement=True, fraction=0.4, seed=-1)
+ ._plan.to_proto(self.connect)
+ )
+ self.assertEqual(plan.root.sample.lower_bound, 0.0)
+ self.assertEqual(plan.root.sample.upper_bound, 0.4)
+ self.assertEqual(plan.root.sample.with_replacement, True)
+ self.assertEqual(plan.root.sample.seed.seed, -1)
+
def test_relation_alias(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.alias("table_alias")._plan.to_proto(self.connect)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]