This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d9c604ec932 [SPARK-41066][CONNECT][PYTHON] Implement `DataFrame.sampleBy` and `DataFrame.stat.sampleBy`
d9c604ec932 is described below
commit d9c604ec9322117fce0c9b3302c3cd73f5d16df7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 2 09:31:21 2023 +0900
[SPARK-41066][CONNECT][PYTHON] Implement `DataFrame.sampleBy` and `DataFrame.stat.sampleBy`
### What changes were proposed in this pull request?
Implement `DataFrame.sampleBy` and `DataFrame.stat.sampleBy` for the Spark Connect Python client.
### Why are the changes needed?
For API coverage: this brings the Spark Connect Python client closer to parity with the existing PySpark DataFrame API.
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds the new APIs `DataFrame.sampleBy` and `DataFrame.stat.sampleBy` to the Spark Connect Python client.
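For example (an illustrative sketch, not part of this patch; it assumes a SparkSession named `spark` backed by Spark Connect):

```python
# Stratified sampling: keep ~10% of the rows whose key == 0 and ~20% of
# those whose key == 1. Strata missing from `fractions` (key == 2 here)
# are treated as fraction 0.0, so they are dropped entirely.
df = spark.range(0, 100).selectExpr("id % 3 AS key")
sampled = df.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
sampled.groupBy("key").count().orderBy("key").show()
```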
### How was this patch tested?
Added a unit test, `test_stat_sample_by` in `test_connect_basic.py`.
Closes #39328 from zhengruifeng/connect_df_sampleby.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 29 +++
.../sql/connect/planner/SparkConnectPlanner.scala | 24 ++-
python/pyspark/sql/connect/dataframe.py | 27 +++
python/pyspark/sql/connect/plan.py | 44 +++++
python/pyspark/sql/connect/proto/relations_pb2.py | 219 ++++++++++++---------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 97 +++++++++
.../sql/tests/connect/test_connect_basic.py | 28 +++
7 files changed, 371 insertions(+), 97 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index db3565eda61..2d834f3fd8c 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -74,6 +74,7 @@ message Relation {
StatCorr corr = 104;
StatApproxQuantile approx_quantile = 105;
StatFreqItems freq_items = 106;
+ StatSampleBy sample_by = 107;
// Catalog API (experimental / unstable)
Catalog catalog = 200;
@@ -546,6 +547,34 @@ message StatFreqItems {
optional double support = 3;
}
+
+// Returns a stratified sample without replacement based on the fraction
+// given on each stratum.
+message StatSampleBy {
+ // (Required) The input relation.
+ Relation input = 1;
+
+ // (Required) The column that defines strata.
+ Expression col = 2;
+
+ // (Required) Sampling fraction for each stratum.
+ //
+ // If a stratum is not specified, we treat its fraction as zero.
+ repeated Fraction fractions = 3;
+
+ // (Optional) The random seed.
+ optional int64 seed = 5;
+
+ message Fraction {
+ // (Required) The stratum.
+ Expression.Literal stratum = 1;
+
+ // (Required) The fraction value. Must be in [0, 1].
+ double fraction = 2;
+ }
+}
+
+
// Replaces null values.
// It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
// Following 3 parameter combinations are supported:
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dcfdc3f8b52..d7e2908a1c5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -98,6 +98,8 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Relation.RelTypeCase.CROSSTAB =>
transformStatCrosstab(rel.getCrosstab)
case proto.Relation.RelTypeCase.FREQ_ITEMS =>
transformStatFreqItems(rel.getFreqItems)
+ case proto.Relation.RelTypeCase.SAMPLE_BY =>
+ transformStatSampleBy(rel.getSampleBy)
case proto.Relation.RelTypeCase.TO_SCHEMA =>
transformToSchema(rel.getToSchema)
case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_SAME_LENGTH_NAMES =>
transformRenameColumnsBySamelenghtNames(rel.getRenameColumnsBySameLengthNames)
@@ -419,6 +421,26 @@ class SparkConnectPlanner(session: SparkSession) {
}
}
+ private def transformStatSampleBy(rel: proto.StatSampleBy): LogicalPlan = {
+ val fractions = mutable.Map.empty[Any, Double]
+ rel.getFractionsList.asScala.toSeq.foreach { protoFraction =>
+ val stratum = transformLiteral(protoFraction.getStratum) match {
+ case Literal(s, StringType) if s != null => s.toString
+ case literal => literal.value
+ }
+ fractions.update(stratum, protoFraction.getFraction)
+ }
+
+ Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .stat
+ .sampleBy(
+ col = Column(transformExpression(rel.getCol)),
+ fractions = fractions.toMap,
+ seed = if (rel.hasSeed) rel.getSeed else Utils.random.nextLong)
+ .logicalPlan
+ }
+
private def transformToSchema(rel: proto.ToSchema): LogicalPlan = {
val schema = DataTypeProtoConverter.toCatalystType(rel.getSchema)
assert(schema.isInstanceOf[StructType])
@@ -697,7 +719,7 @@ class SparkConnectPlanner(session: SparkSession) {
* @return
* Expression
*/
- private def transformLiteral(lit: proto.Expression.Literal): Expression = {
+ private def transformLiteral(lit: proto.Expression.Literal): Literal = {
toCatalystExpression(lit)
}
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c5ab22b34bd..95582e86390 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -952,6 +952,26 @@ class DataFrame:
freqItems.__doc__ = PySparkDataFrame.freqItems.__doc__
+ def sampleBy(
+ self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
+ ) -> "DataFrame":
+ if not isinstance(col, (Column, str)):
+ raise TypeError("col must be a string or a column, but got %r" % type(col))
+ if not isinstance(fractions, dict):
+ raise TypeError("fractions must be a dict but got %r" % type(fractions))
+ for k, v in fractions.items():
+ if not isinstance(k, (float, int, str)):
+ raise TypeError("key must be float, int, or string, but got %r" % type(k))
+ fractions[k] = float(v)
+ seed = seed if seed is not None else random.randint(0, sys.maxsize)
+
+ return DataFrame.withPlan(
+ plan.StatSampleBy(child=self._plan, col=col, fractions=fractions, seed=seed),
+ session=self._session,
+ )
+
+ sampleBy.__doc__ = PySparkDataFrame.sampleBy.__doc__
+
def _get_alias(self) -> Optional[str]:
p = self._plan
while p is not None:
@@ -1344,5 +1364,12 @@ class DataFrameStatFunctions:
freqItems.__doc__ = DataFrame.freqItems.__doc__
+ def sampleBy(
+ self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None
+ ) -> DataFrame:
+ return self.df.sampleBy(col, fractions, seed)
+
+ sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+
DataFrameStatFunctions.__doc__ = PySparkDataFrameStatFunctions.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index f567d88137a..f10687cc82e 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1140,6 +1140,50 @@ class StatFreqItems(LogicalPlan):
return plan
+class StatSampleBy(LogicalPlan):
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ col: "ColumnOrName",
+ fractions: Dict[Any, float],
+ seed: Optional[int],
+ ) -> None:
+ super().__init__(child)
+
+ assert col is not None and isinstance(col, (Column, str))
+
+ assert fractions is not None and isinstance(fractions, dict)
+ for k, v in fractions.items():
+ assert v is not None and isinstance(v, float)
+
+ assert seed is None or isinstance(seed, int)
+
+ if isinstance(col, Column):
+ self._col = col
+ else:
+ self._col = Column(ColumnReference(col))
+
+ self._fractions = fractions
+
+ self._seed = seed
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+
+ plan = proto.Relation()
+ plan.sample_by.input.CopyFrom(self._child.plan(session))
+ plan.sample_by.col.CopyFrom(self._col._expr.to_plan(session))
+ if len(self._fractions) > 0:
+ for k, v in self._fractions.items():
+ fraction = proto.StatSampleBy.Fraction()
+ fraction.stratum.CopyFrom(LiteralExpression._from_value(k).to_plan(session).literal)
+ fraction.fraction = float(v)
+ plan.sample_by.fractions.append(fraction)
+ if self._seed is not None:
+ plan.sample_by.seed = self._seed
+ return plan
+
+
class StatCorr(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None:
super().__init__(child)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 6e2904b0294..cf0f2eb3513 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -73,6 +73,8 @@ _STATCOV = DESCRIPTOR.message_types_by_name["StatCov"]
_STATCORR = DESCRIPTOR.message_types_by_name["StatCorr"]
_STATAPPROXQUANTILE = DESCRIPTOR.message_types_by_name["StatApproxQuantile"]
_STATFREQITEMS = DESCRIPTOR.message_types_by_name["StatFreqItems"]
+_STATSAMPLEBY = DESCRIPTOR.message_types_by_name["StatSampleBy"]
+_STATSAMPLEBY_FRACTION = _STATSAMPLEBY.nested_types_by_name["Fraction"]
_NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
_NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
_NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
@@ -449,6 +451,27 @@ StatFreqItems = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(StatFreqItems)
+StatSampleBy = _reflection.GeneratedProtocolMessageType(
+ "StatSampleBy",
+ (_message.Message,),
+ {
+ "Fraction": _reflection.GeneratedProtocolMessageType(
+ "Fraction",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _STATSAMPLEBY_FRACTION,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.StatSampleBy.Fraction)
+ },
+ ),
+ "DESCRIPTOR": _STATSAMPLEBY,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.StatSampleBy)
+ },
+)
+_sym_db.RegisterMessage(StatSampleBy)
+_sym_db.RegisterMessage(StatSampleBy.Fraction)
+
NAFill = _reflection.GeneratedProtocolMessageType(
"NAFill",
(_message.Message,),
@@ -588,99 +611,103 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 165
- _RELATION._serialized_end = 2518
- _UNKNOWN._serialized_start = 2520
- _UNKNOWN._serialized_end = 2529
- _RELATIONCOMMON._serialized_start = 2531
- _RELATIONCOMMON._serialized_end = 2580
- _SQL._serialized_start = 2582
- _SQL._serialized_end = 2609
- _READ._serialized_start = 2612
- _READ._serialized_end = 3038
- _READ_NAMEDTABLE._serialized_start = 2754
- _READ_NAMEDTABLE._serialized_end = 2815
- _READ_DATASOURCE._serialized_start = 2818
- _READ_DATASOURCE._serialized_end = 3025
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2956
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3014
- _PROJECT._serialized_start = 3040
- _PROJECT._serialized_end = 3157
- _FILTER._serialized_start = 3159
- _FILTER._serialized_end = 3271
- _JOIN._serialized_start = 3274
- _JOIN._serialized_end = 3745
- _JOIN_JOINTYPE._serialized_start = 3537
- _JOIN_JOINTYPE._serialized_end = 3745
- _SETOPERATION._serialized_start = 3748
- _SETOPERATION._serialized_end = 4144
- _SETOPERATION_SETOPTYPE._serialized_start = 4007
- _SETOPERATION_SETOPTYPE._serialized_end = 4121
- _LIMIT._serialized_start = 4146
- _LIMIT._serialized_end = 4222
- _OFFSET._serialized_start = 4224
- _OFFSET._serialized_end = 4303
- _TAIL._serialized_start = 4305
- _TAIL._serialized_end = 4380
- _AGGREGATE._serialized_start = 4383
- _AGGREGATE._serialized_end = 4965
- _AGGREGATE_PIVOT._serialized_start = 4722
- _AGGREGATE_PIVOT._serialized_end = 4833
- _AGGREGATE_GROUPTYPE._serialized_start = 4836
- _AGGREGATE_GROUPTYPE._serialized_end = 4965
- _SORT._serialized_start = 4968
- _SORT._serialized_end = 5128
- _DROP._serialized_start = 5130
- _DROP._serialized_end = 5230
- _DEDUPLICATE._serialized_start = 5233
- _DEDUPLICATE._serialized_end = 5404
- _LOCALRELATION._serialized_start = 5407
- _LOCALRELATION._serialized_end = 5544
- _SAMPLE._serialized_start = 5547
- _SAMPLE._serialized_end = 5820
- _RANGE._serialized_start = 5823
- _RANGE._serialized_end = 5968
- _SUBQUERYALIAS._serialized_start = 5970
- _SUBQUERYALIAS._serialized_end = 6084
- _REPARTITION._serialized_start = 6087
- _REPARTITION._serialized_end = 6229
- _SHOWSTRING._serialized_start = 6232
- _SHOWSTRING._serialized_end = 6374
- _STATSUMMARY._serialized_start = 6376
- _STATSUMMARY._serialized_end = 6468
- _STATDESCRIBE._serialized_start = 6470
- _STATDESCRIBE._serialized_end = 6551
- _STATCROSSTAB._serialized_start = 6553
- _STATCROSSTAB._serialized_end = 6654
- _STATCOV._serialized_start = 6656
- _STATCOV._serialized_end = 6752
- _STATCORR._serialized_start = 6755
- _STATCORR._serialized_end = 6892
- _STATAPPROXQUANTILE._serialized_start = 6895
- _STATAPPROXQUANTILE._serialized_end = 7059
- _STATFREQITEMS._serialized_start = 7061
- _STATFREQITEMS._serialized_end = 7186
- _NAFILL._serialized_start = 7189
- _NAFILL._serialized_end = 7323
- _NADROP._serialized_start = 7326
- _NADROP._serialized_end = 7460
- _NAREPLACE._serialized_start = 7463
- _NAREPLACE._serialized_end = 7759
- _NAREPLACE_REPLACEMENT._serialized_start = 7618
- _NAREPLACE_REPLACEMENT._serialized_end = 7759
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7761
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7875
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7878
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8137
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8070
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8137
- _WITHCOLUMNS._serialized_start = 8140
- _WITHCOLUMNS._serialized_end = 8271
- _HINT._serialized_start = 8274
- _HINT._serialized_end = 8414
- _UNPIVOT._serialized_start = 8417
- _UNPIVOT._serialized_end = 8663
- _TOSCHEMA._serialized_start = 8665
- _TOSCHEMA._serialized_end = 8771
- _REPARTITIONBYEXPRESSION._serialized_start = 8774
- _REPARTITIONBYEXPRESSION._serialized_end = 8977
+ _RELATION._serialized_end = 2578
+ _UNKNOWN._serialized_start = 2580
+ _UNKNOWN._serialized_end = 2589
+ _RELATIONCOMMON._serialized_start = 2591
+ _RELATIONCOMMON._serialized_end = 2640
+ _SQL._serialized_start = 2642
+ _SQL._serialized_end = 2669
+ _READ._serialized_start = 2672
+ _READ._serialized_end = 3098
+ _READ_NAMEDTABLE._serialized_start = 2814
+ _READ_NAMEDTABLE._serialized_end = 2875
+ _READ_DATASOURCE._serialized_start = 2878
+ _READ_DATASOURCE._serialized_end = 3085
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3016
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3074
+ _PROJECT._serialized_start = 3100
+ _PROJECT._serialized_end = 3217
+ _FILTER._serialized_start = 3219
+ _FILTER._serialized_end = 3331
+ _JOIN._serialized_start = 3334
+ _JOIN._serialized_end = 3805
+ _JOIN_JOINTYPE._serialized_start = 3597
+ _JOIN_JOINTYPE._serialized_end = 3805
+ _SETOPERATION._serialized_start = 3808
+ _SETOPERATION._serialized_end = 4204
+ _SETOPERATION_SETOPTYPE._serialized_start = 4067
+ _SETOPERATION_SETOPTYPE._serialized_end = 4181
+ _LIMIT._serialized_start = 4206
+ _LIMIT._serialized_end = 4282
+ _OFFSET._serialized_start = 4284
+ _OFFSET._serialized_end = 4363
+ _TAIL._serialized_start = 4365
+ _TAIL._serialized_end = 4440
+ _AGGREGATE._serialized_start = 4443
+ _AGGREGATE._serialized_end = 5025
+ _AGGREGATE_PIVOT._serialized_start = 4782
+ _AGGREGATE_PIVOT._serialized_end = 4893
+ _AGGREGATE_GROUPTYPE._serialized_start = 4896
+ _AGGREGATE_GROUPTYPE._serialized_end = 5025
+ _SORT._serialized_start = 5028
+ _SORT._serialized_end = 5188
+ _DROP._serialized_start = 5190
+ _DROP._serialized_end = 5290
+ _DEDUPLICATE._serialized_start = 5293
+ _DEDUPLICATE._serialized_end = 5464
+ _LOCALRELATION._serialized_start = 5467
+ _LOCALRELATION._serialized_end = 5604
+ _SAMPLE._serialized_start = 5607
+ _SAMPLE._serialized_end = 5880
+ _RANGE._serialized_start = 5883
+ _RANGE._serialized_end = 6028
+ _SUBQUERYALIAS._serialized_start = 6030
+ _SUBQUERYALIAS._serialized_end = 6144
+ _REPARTITION._serialized_start = 6147
+ _REPARTITION._serialized_end = 6289
+ _SHOWSTRING._serialized_start = 6292
+ _SHOWSTRING._serialized_end = 6434
+ _STATSUMMARY._serialized_start = 6436
+ _STATSUMMARY._serialized_end = 6528
+ _STATDESCRIBE._serialized_start = 6530
+ _STATDESCRIBE._serialized_end = 6611
+ _STATCROSSTAB._serialized_start = 6613
+ _STATCROSSTAB._serialized_end = 6714
+ _STATCOV._serialized_start = 6716
+ _STATCOV._serialized_end = 6812
+ _STATCORR._serialized_start = 6815
+ _STATCORR._serialized_end = 6952
+ _STATAPPROXQUANTILE._serialized_start = 6955
+ _STATAPPROXQUANTILE._serialized_end = 7119
+ _STATFREQITEMS._serialized_start = 7121
+ _STATFREQITEMS._serialized_end = 7246
+ _STATSAMPLEBY._serialized_start = 7249
+ _STATSAMPLEBY._serialized_end = 7558
+ _STATSAMPLEBY_FRACTION._serialized_start = 7450
+ _STATSAMPLEBY_FRACTION._serialized_end = 7549
+ _NAFILL._serialized_start = 7561
+ _NAFILL._serialized_end = 7695
+ _NADROP._serialized_start = 7698
+ _NADROP._serialized_end = 7832
+ _NAREPLACE._serialized_start = 7835
+ _NAREPLACE._serialized_end = 8131
+ _NAREPLACE_REPLACEMENT._serialized_start = 7990
+ _NAREPLACE_REPLACEMENT._serialized_end = 8131
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8133
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8247
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8250
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8509
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8442
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8509
+ _WITHCOLUMNS._serialized_start = 8512
+ _WITHCOLUMNS._serialized_end = 8643
+ _HINT._serialized_start = 8646
+ _HINT._serialized_end = 8786
+ _UNPIVOT._serialized_start = 8789
+ _UNPIVOT._serialized_end = 9035
+ _TOSCHEMA._serialized_start = 9037
+ _TOSCHEMA._serialized_end = 9143
+ _REPARTITIONBYEXPRESSION._serialized_start = 9146
+ _REPARTITIONBYEXPRESSION._serialized_end = 9349
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 96915de60dc..500b9d8804c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -99,6 +99,7 @@ class Relation(google.protobuf.message.Message):
CORR_FIELD_NUMBER: builtins.int
APPROX_QUANTILE_FIELD_NUMBER: builtins.int
FREQ_ITEMS_FIELD_NUMBER: builtins.int
+ SAMPLE_BY_FIELD_NUMBER: builtins.int
CATALOG_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
UNKNOWN_FIELD_NUMBER: builtins.int
@@ -179,6 +180,8 @@ class Relation(google.protobuf.message.Message):
@property
def freq_items(self) -> global___StatFreqItems: ...
@property
+ def sample_by(self) -> global___StatSampleBy: ...
+ @property
def catalog(self) -> pyspark.sql.connect.proto.catalog_pb2.Catalog:
"""Catalog API (experimental / unstable)"""
@property
@@ -228,6 +231,7 @@ class Relation(google.protobuf.message.Message):
corr: global___StatCorr | None = ...,
approx_quantile: global___StatApproxQuantile | None = ...,
freq_items: global___StatFreqItems | None = ...,
+ sample_by: global___StatSampleBy | None = ...,
catalog: pyspark.sql.connect.proto.catalog_pb2.Catalog | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
unknown: global___Unknown | None = ...,
@@ -295,6 +299,8 @@ class Relation(google.protobuf.message.Message):
b"replace",
"sample",
b"sample",
+ "sample_by",
+ b"sample_by",
"set_op",
b"set_op",
"show_string",
@@ -382,6 +388,8 @@ class Relation(google.protobuf.message.Message):
b"replace",
"sample",
b"sample",
+ "sample_by",
+ b"sample_by",
"set_op",
b"set_op",
"show_string",
@@ -445,6 +453,7 @@ class Relation(google.protobuf.message.Message):
"corr",
"approx_quantile",
"freq_items",
+ "sample_by",
"catalog",
"extension",
"unknown",
@@ -1910,6 +1919,94 @@ class StatFreqItems(google.protobuf.message.Message):
global___StatFreqItems = StatFreqItems
+class StatSampleBy(google.protobuf.message.Message):
+ """Returns a stratified sample without replacement based on the fraction
+ given on each stratum.
+ """
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class Fraction(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ STRATUM_FIELD_NUMBER: builtins.int
+ FRACTION_FIELD_NUMBER: builtins.int
+ @property
+ def stratum(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression.Literal:
+ """(Required) The stratum."""
+ fraction: builtins.float
+ """(Required) The fraction value. Must be in [0, 1]."""
+ def __init__(
+ self,
+ *,
+ stratum: pyspark.sql.connect.proto.expressions_pb2.Expression.Literal | None = ...,
+ fraction: builtins.float = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["stratum", b"stratum"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal["fraction", b"fraction", "stratum", b"stratum"],
+ ) -> None: ...
+
+ INPUT_FIELD_NUMBER: builtins.int
+ COL_FIELD_NUMBER: builtins.int
+ FRACTIONS_FIELD_NUMBER: builtins.int
+ SEED_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) The input relation."""
+ @property
+ def col(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression:
+ """(Required) The column that defines strata."""
+ @property
+ def fractions(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___StatSampleBy.Fraction
+ ]:
+ """(Required) Sampling fraction for each stratum.
+
+ If a stratum is not specified, we treat its fraction as zero.
+ """
+ seed: builtins.int
+ """(Optional) The random seed."""
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ col: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
+ fractions: collections.abc.Iterable[global___StatSampleBy.Fraction] | None = ...,
+ seed: builtins.int | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_seed", b"_seed", "col", b"col", "input", b"input", "seed", b"seed"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_seed",
+ b"_seed",
+ "col",
+ b"col",
+ "fractions",
+ b"fractions",
+ "input",
+ b"input",
+ "seed",
+ b"seed",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
+ ) -> typing_extensions.Literal["seed"] | None: ...
+
+global___StatSampleBy = StatSampleBy
+
class NAFill(google.protobuf.message.Message):
"""Replaces null values.
It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0b615d2e32a..6a65e412dfd 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1194,6 +1194,34 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
):
self.connect.read.table(self.tbl_name2).stat.freqItems("col1")
+ def test_stat_sample_by(self):
+ # SPARK-41069: Test stat.sample_by
+
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key"))
+ sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key"))
+
+ self.assert_eq(
+ cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
+ .groupBy("key")
+ .agg(CF.count(CF.lit(1)))
+ .orderBy("key")
+ .toPandas(),
+ sdf.sampleBy(sdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
+ .groupBy("key")
+ .agg(SF.count(SF.lit(1)))
+ .orderBy("key")
+ .toPandas(),
+ )
+
+ with self.assertRaisesRegex(TypeError, "key must be float, int, or string"):
+ cdf.stat.sampleBy(cdf.key, fractions={0: 0.1, None: 0.2}, seed=0)
+
+ with self.assertRaises(SparkConnectException):
+ cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 1.2}, seed=0).show()
+
def test_repr(self):
# SPARK-41213: Test the __repr__ method
query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]