This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d9c604ec932 [SPARK-41066][CONNECT][PYTHON] Implement `DataFrame.sampleBy` and `DataFrame.stat.sampleBy`
d9c604ec932 is described below

commit d9c604ec9322117fce0c9b3302c3cd73f5d16df7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 2 09:31:21 2023 +0900

    [SPARK-41066][CONNECT][PYTHON] Implement `DataFrame.sampleBy` and `DataFrame.stat.sampleBy`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.sampleBy` and `DataFrame.stat.sampleBy` for the Spark Connect Python client.
    
    ### Why are the changes needed?
    For API coverage: parity with the existing PySpark DataFrame API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds `DataFrame.sampleBy` and `DataFrame.stat.sampleBy` to the Python client; see the usage sketch below.
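
    A minimal usage sketch, adapted from the unit test added in this PR. It assumes an
    active `spark` session; with the Connect client, import functions from
    `pyspark.sql.connect.functions` instead of `pyspark.sql.functions`:

    ```python
    from pyspark.sql import functions as F

    # Stratify on `key` (values 0, 1, 2): keep ~10% of the key=0 stratum and
    # ~20% of the key=1 stratum. Strata absent from `fractions` (here key=2)
    # are treated as having fraction zero.
    df = spark.range(0, 100).select((F.col("id") % 3).alias("key"))
    sampled = df.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
    sampled.groupBy("key").agg(F.count(F.lit(1))).orderBy("key").show()
    ```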
    
    ### How was this patch tested?
    Added a unit test (`test_stat_sample_by` in `python/pyspark/sql/tests/connect/test_connect_basic.py`).
    
    Closes #39328 from zhengruifeng/connect_df_sampleby.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/protobuf/spark/connect/relations.proto    |  29 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  24 ++-
 python/pyspark/sql/connect/dataframe.py            |  27 +++
 python/pyspark/sql/connect/plan.py                 |  44 +++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 219 ++++++++++++---------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  97 +++++++++
 .../sql/tests/connect/test_connect_basic.py        |  28 +++
 7 files changed, 371 insertions(+), 97 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index db3565eda61..2d834f3fd8c 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -74,6 +74,7 @@ message Relation {
     StatCorr corr = 104;
     StatApproxQuantile approx_quantile = 105;
     StatFreqItems freq_items = 106;
+    StatSampleBy sample_by = 107;
 
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
@@ -546,6 +547,34 @@ message StatFreqItems {
   optional double support = 3;
 }
 
+
+// Returns a stratified sample without replacement based on the fraction
+// given on each stratum.
+message StatSampleBy {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) The column that defines strata.
+  Expression col = 2;
+
+  // (Required) Sampling fraction for each stratum.
+  //
+  // If a stratum is not specified, we treat its fraction as zero.
+  repeated Fraction fractions = 3;
+
+  // (Optional) The random seed.
+  optional int64 seed = 5;
+
+  message Fraction {
+    // (Required) The stratum.
+    Expression.Literal stratum = 1;
+
+    // (Required) The fraction value. Must be in [0, 1].
+    double fraction = 2;
+  }
+}
+
+
 // Replaces null values.
 // It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
 // Following 3 parameter combinations are supported:
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dcfdc3f8b52..d7e2908a1c5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -98,6 +98,8 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.CROSSTAB =>
         transformStatCrosstab(rel.getCrosstab)
      case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems)
+      case proto.Relation.RelTypeCase.SAMPLE_BY =>
+        transformStatSampleBy(rel.getSampleBy)
      case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema)
      case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_SAME_LENGTH_NAMES =>
        transformRenameColumnsBySamelenghtNames(rel.getRenameColumnsBySameLengthNames)
@@ -419,6 +421,26 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  private def transformStatSampleBy(rel: proto.StatSampleBy): LogicalPlan = {
+    val fractions = mutable.Map.empty[Any, Double]
+    rel.getFractionsList.asScala.toSeq.foreach { protoFraction =>
+      val stratum = transformLiteral(protoFraction.getStratum) match {
+        case Literal(s, StringType) if s != null => s.toString
+        case literal => literal.value
+      }
+      fractions.update(stratum, protoFraction.getFraction)
+    }
+
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .stat
+      .sampleBy(
+        col = Column(transformExpression(rel.getCol)),
+        fractions = fractions.toMap,
+        seed = if (rel.hasSeed) rel.getSeed else Utils.random.nextLong)
+      .logicalPlan
+  }
+
   private def transformToSchema(rel: proto.ToSchema): LogicalPlan = {
     val schema = DataTypeProtoConverter.toCatalystType(rel.getSchema)
     assert(schema.isInstanceOf[StructType])
@@ -697,7 +719,7 @@ class SparkConnectPlanner(session: SparkSession) {
    * @return
    *   Expression
    */
-  private def transformLiteral(lit: proto.Expression.Literal): Expression = {
+  private def transformLiteral(lit: proto.Expression.Literal): Literal = {
     toCatalystExpression(lit)
   }
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c5ab22b34bd..95582e86390 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -952,6 +952,26 @@ class DataFrame:
 
     freqItems.__doc__ = PySparkDataFrame.freqItems.__doc__
 
+    def sampleBy(
+        self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
+    ) -> "DataFrame":
+        if not isinstance(col, (Column, str)):
+            raise TypeError("col must be a string or a column, but got %r" % type(col))
+        if not isinstance(fractions, dict):
+            raise TypeError("fractions must be a dict but got %r" % type(fractions))
+        for k, v in fractions.items():
+            if not isinstance(k, (float, int, str)):
+                raise TypeError("key must be float, int, or string, but got %r" % type(k))
+            fractions[k] = float(v)
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+
+        return DataFrame.withPlan(
+            plan.StatSampleBy(child=self._plan, col=col, fractions=fractions, seed=seed),
+            session=self._session,
+        )
+
+    sampleBy.__doc__ = PySparkDataFrame.sampleBy.__doc__
+
     def _get_alias(self) -> Optional[str]:
         p = self._plan
         while p is not None:
@@ -1344,5 +1364,12 @@ class DataFrameStatFunctions:
 
     freqItems.__doc__ = DataFrame.freqItems.__doc__
 
+    def sampleBy(
+        self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None
+    ) -> DataFrame:
+        return self.df.sampleBy(col, fractions, seed)
+
+    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+
 
 DataFrameStatFunctions.__doc__ = PySparkDataFrameStatFunctions.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index f567d88137a..f10687cc82e 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1140,6 +1140,50 @@ class StatFreqItems(LogicalPlan):
         return plan
 
 
+class StatSampleBy(LogicalPlan):
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        col: "ColumnOrName",
+        fractions: Dict[Any, float],
+        seed: Optional[int],
+    ) -> None:
+        super().__init__(child)
+
+        assert col is not None and isinstance(col, (Column, str))
+
+        assert fractions is not None and isinstance(fractions, dict)
+        for k, v in fractions.items():
+            assert v is not None and isinstance(v, float)
+
+        assert seed is None or isinstance(seed, int)
+
+        if isinstance(col, Column):
+            self._col = col
+        else:
+            self._col = Column(ColumnReference(col))
+
+        self._fractions = fractions
+
+        self._seed = seed
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.sample_by.input.CopyFrom(self._child.plan(session))
+        plan.sample_by.col.CopyFrom(self._col._expr.to_plan(session))
+        if len(self._fractions) > 0:
+            for k, v in self._fractions.items():
+                fraction = proto.StatSampleBy.Fraction()
+                fraction.stratum.CopyFrom(LiteralExpression._from_value(k).to_plan(session).literal)
+                fraction.fraction = float(v)
+                plan.sample_by.fractions.append(fraction)
+        if self._seed is not None:
+            plan.sample_by.seed = self._seed
+        return plan
+
+
 class StatCorr(LogicalPlan):
    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None:
         super().__init__(child)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 6e2904b0294..cf0f2eb3513 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -73,6 +73,8 @@ _STATCOV = DESCRIPTOR.message_types_by_name["StatCov"]
 _STATCORR = DESCRIPTOR.message_types_by_name["StatCorr"]
 _STATAPPROXQUANTILE = DESCRIPTOR.message_types_by_name["StatApproxQuantile"]
 _STATFREQITEMS = DESCRIPTOR.message_types_by_name["StatFreqItems"]
+_STATSAMPLEBY = DESCRIPTOR.message_types_by_name["StatSampleBy"]
+_STATSAMPLEBY_FRACTION = _STATSAMPLEBY.nested_types_by_name["Fraction"]
 _NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
 _NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
 _NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
@@ -449,6 +451,27 @@ StatFreqItems = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(StatFreqItems)
 
+StatSampleBy = _reflection.GeneratedProtocolMessageType(
+    "StatSampleBy",
+    (_message.Message,),
+    {
+        "Fraction": _reflection.GeneratedProtocolMessageType(
+            "Fraction",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _STATSAMPLEBY_FRACTION,
+                "__module__": "spark.connect.relations_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.StatSampleBy.Fraction)
+            },
+        ),
+        "DESCRIPTOR": _STATSAMPLEBY,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.StatSampleBy)
+    },
+)
+_sym_db.RegisterMessage(StatSampleBy)
+_sym_db.RegisterMessage(StatSampleBy.Fraction)
+
 NAFill = _reflection.GeneratedProtocolMessageType(
     "NAFill",
     (_message.Message,),
@@ -588,99 +611,103 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2518
-    _UNKNOWN._serialized_start = 2520
-    _UNKNOWN._serialized_end = 2529
-    _RELATIONCOMMON._serialized_start = 2531
-    _RELATIONCOMMON._serialized_end = 2580
-    _SQL._serialized_start = 2582
-    _SQL._serialized_end = 2609
-    _READ._serialized_start = 2612
-    _READ._serialized_end = 3038
-    _READ_NAMEDTABLE._serialized_start = 2754
-    _READ_NAMEDTABLE._serialized_end = 2815
-    _READ_DATASOURCE._serialized_start = 2818
-    _READ_DATASOURCE._serialized_end = 3025
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2956
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3014
-    _PROJECT._serialized_start = 3040
-    _PROJECT._serialized_end = 3157
-    _FILTER._serialized_start = 3159
-    _FILTER._serialized_end = 3271
-    _JOIN._serialized_start = 3274
-    _JOIN._serialized_end = 3745
-    _JOIN_JOINTYPE._serialized_start = 3537
-    _JOIN_JOINTYPE._serialized_end = 3745
-    _SETOPERATION._serialized_start = 3748
-    _SETOPERATION._serialized_end = 4144
-    _SETOPERATION_SETOPTYPE._serialized_start = 4007
-    _SETOPERATION_SETOPTYPE._serialized_end = 4121
-    _LIMIT._serialized_start = 4146
-    _LIMIT._serialized_end = 4222
-    _OFFSET._serialized_start = 4224
-    _OFFSET._serialized_end = 4303
-    _TAIL._serialized_start = 4305
-    _TAIL._serialized_end = 4380
-    _AGGREGATE._serialized_start = 4383
-    _AGGREGATE._serialized_end = 4965
-    _AGGREGATE_PIVOT._serialized_start = 4722
-    _AGGREGATE_PIVOT._serialized_end = 4833
-    _AGGREGATE_GROUPTYPE._serialized_start = 4836
-    _AGGREGATE_GROUPTYPE._serialized_end = 4965
-    _SORT._serialized_start = 4968
-    _SORT._serialized_end = 5128
-    _DROP._serialized_start = 5130
-    _DROP._serialized_end = 5230
-    _DEDUPLICATE._serialized_start = 5233
-    _DEDUPLICATE._serialized_end = 5404
-    _LOCALRELATION._serialized_start = 5407
-    _LOCALRELATION._serialized_end = 5544
-    _SAMPLE._serialized_start = 5547
-    _SAMPLE._serialized_end = 5820
-    _RANGE._serialized_start = 5823
-    _RANGE._serialized_end = 5968
-    _SUBQUERYALIAS._serialized_start = 5970
-    _SUBQUERYALIAS._serialized_end = 6084
-    _REPARTITION._serialized_start = 6087
-    _REPARTITION._serialized_end = 6229
-    _SHOWSTRING._serialized_start = 6232
-    _SHOWSTRING._serialized_end = 6374
-    _STATSUMMARY._serialized_start = 6376
-    _STATSUMMARY._serialized_end = 6468
-    _STATDESCRIBE._serialized_start = 6470
-    _STATDESCRIBE._serialized_end = 6551
-    _STATCROSSTAB._serialized_start = 6553
-    _STATCROSSTAB._serialized_end = 6654
-    _STATCOV._serialized_start = 6656
-    _STATCOV._serialized_end = 6752
-    _STATCORR._serialized_start = 6755
-    _STATCORR._serialized_end = 6892
-    _STATAPPROXQUANTILE._serialized_start = 6895
-    _STATAPPROXQUANTILE._serialized_end = 7059
-    _STATFREQITEMS._serialized_start = 7061
-    _STATFREQITEMS._serialized_end = 7186
-    _NAFILL._serialized_start = 7189
-    _NAFILL._serialized_end = 7323
-    _NADROP._serialized_start = 7326
-    _NADROP._serialized_end = 7460
-    _NAREPLACE._serialized_start = 7463
-    _NAREPLACE._serialized_end = 7759
-    _NAREPLACE_REPLACEMENT._serialized_start = 7618
-    _NAREPLACE_REPLACEMENT._serialized_end = 7759
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7761
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7875
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7878
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8137
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8070
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8137
-    _WITHCOLUMNS._serialized_start = 8140
-    _WITHCOLUMNS._serialized_end = 8271
-    _HINT._serialized_start = 8274
-    _HINT._serialized_end = 8414
-    _UNPIVOT._serialized_start = 8417
-    _UNPIVOT._serialized_end = 8663
-    _TOSCHEMA._serialized_start = 8665
-    _TOSCHEMA._serialized_end = 8771
-    _REPARTITIONBYEXPRESSION._serialized_start = 8774
-    _REPARTITIONBYEXPRESSION._serialized_end = 8977
+    _RELATION._serialized_end = 2578
+    _UNKNOWN._serialized_start = 2580
+    _UNKNOWN._serialized_end = 2589
+    _RELATIONCOMMON._serialized_start = 2591
+    _RELATIONCOMMON._serialized_end = 2640
+    _SQL._serialized_start = 2642
+    _SQL._serialized_end = 2669
+    _READ._serialized_start = 2672
+    _READ._serialized_end = 3098
+    _READ_NAMEDTABLE._serialized_start = 2814
+    _READ_NAMEDTABLE._serialized_end = 2875
+    _READ_DATASOURCE._serialized_start = 2878
+    _READ_DATASOURCE._serialized_end = 3085
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3016
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3074
+    _PROJECT._serialized_start = 3100
+    _PROJECT._serialized_end = 3217
+    _FILTER._serialized_start = 3219
+    _FILTER._serialized_end = 3331
+    _JOIN._serialized_start = 3334
+    _JOIN._serialized_end = 3805
+    _JOIN_JOINTYPE._serialized_start = 3597
+    _JOIN_JOINTYPE._serialized_end = 3805
+    _SETOPERATION._serialized_start = 3808
+    _SETOPERATION._serialized_end = 4204
+    _SETOPERATION_SETOPTYPE._serialized_start = 4067
+    _SETOPERATION_SETOPTYPE._serialized_end = 4181
+    _LIMIT._serialized_start = 4206
+    _LIMIT._serialized_end = 4282
+    _OFFSET._serialized_start = 4284
+    _OFFSET._serialized_end = 4363
+    _TAIL._serialized_start = 4365
+    _TAIL._serialized_end = 4440
+    _AGGREGATE._serialized_start = 4443
+    _AGGREGATE._serialized_end = 5025
+    _AGGREGATE_PIVOT._serialized_start = 4782
+    _AGGREGATE_PIVOT._serialized_end = 4893
+    _AGGREGATE_GROUPTYPE._serialized_start = 4896
+    _AGGREGATE_GROUPTYPE._serialized_end = 5025
+    _SORT._serialized_start = 5028
+    _SORT._serialized_end = 5188
+    _DROP._serialized_start = 5190
+    _DROP._serialized_end = 5290
+    _DEDUPLICATE._serialized_start = 5293
+    _DEDUPLICATE._serialized_end = 5464
+    _LOCALRELATION._serialized_start = 5467
+    _LOCALRELATION._serialized_end = 5604
+    _SAMPLE._serialized_start = 5607
+    _SAMPLE._serialized_end = 5880
+    _RANGE._serialized_start = 5883
+    _RANGE._serialized_end = 6028
+    _SUBQUERYALIAS._serialized_start = 6030
+    _SUBQUERYALIAS._serialized_end = 6144
+    _REPARTITION._serialized_start = 6147
+    _REPARTITION._serialized_end = 6289
+    _SHOWSTRING._serialized_start = 6292
+    _SHOWSTRING._serialized_end = 6434
+    _STATSUMMARY._serialized_start = 6436
+    _STATSUMMARY._serialized_end = 6528
+    _STATDESCRIBE._serialized_start = 6530
+    _STATDESCRIBE._serialized_end = 6611
+    _STATCROSSTAB._serialized_start = 6613
+    _STATCROSSTAB._serialized_end = 6714
+    _STATCOV._serialized_start = 6716
+    _STATCOV._serialized_end = 6812
+    _STATCORR._serialized_start = 6815
+    _STATCORR._serialized_end = 6952
+    _STATAPPROXQUANTILE._serialized_start = 6955
+    _STATAPPROXQUANTILE._serialized_end = 7119
+    _STATFREQITEMS._serialized_start = 7121
+    _STATFREQITEMS._serialized_end = 7246
+    _STATSAMPLEBY._serialized_start = 7249
+    _STATSAMPLEBY._serialized_end = 7558
+    _STATSAMPLEBY_FRACTION._serialized_start = 7450
+    _STATSAMPLEBY_FRACTION._serialized_end = 7549
+    _NAFILL._serialized_start = 7561
+    _NAFILL._serialized_end = 7695
+    _NADROP._serialized_start = 7698
+    _NADROP._serialized_end = 7832
+    _NAREPLACE._serialized_start = 7835
+    _NAREPLACE._serialized_end = 8131
+    _NAREPLACE_REPLACEMENT._serialized_start = 7990
+    _NAREPLACE_REPLACEMENT._serialized_end = 8131
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8133
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8247
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8250
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8509
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8442
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8509
+    _WITHCOLUMNS._serialized_start = 8512
+    _WITHCOLUMNS._serialized_end = 8643
+    _HINT._serialized_start = 8646
+    _HINT._serialized_end = 8786
+    _UNPIVOT._serialized_start = 8789
+    _UNPIVOT._serialized_end = 9035
+    _TOSCHEMA._serialized_start = 9037
+    _TOSCHEMA._serialized_end = 9143
+    _REPARTITIONBYEXPRESSION._serialized_start = 9146
+    _REPARTITIONBYEXPRESSION._serialized_end = 9349
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 96915de60dc..500b9d8804c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -99,6 +99,7 @@ class Relation(google.protobuf.message.Message):
     CORR_FIELD_NUMBER: builtins.int
     APPROX_QUANTILE_FIELD_NUMBER: builtins.int
     FREQ_ITEMS_FIELD_NUMBER: builtins.int
+    SAMPLE_BY_FIELD_NUMBER: builtins.int
     CATALOG_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
@@ -179,6 +180,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def freq_items(self) -> global___StatFreqItems: ...
     @property
+    def sample_by(self) -> global___StatSampleBy: ...
+    @property
     def catalog(self) -> pyspark.sql.connect.proto.catalog_pb2.Catalog:
         """Catalog API (experimental / unstable)"""
     @property
@@ -228,6 +231,7 @@ class Relation(google.protobuf.message.Message):
         corr: global___StatCorr | None = ...,
         approx_quantile: global___StatApproxQuantile | None = ...,
         freq_items: global___StatFreqItems | None = ...,
+        sample_by: global___StatSampleBy | None = ...,
         catalog: pyspark.sql.connect.proto.catalog_pb2.Catalog | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
         unknown: global___Unknown | None = ...,
@@ -295,6 +299,8 @@ class Relation(google.protobuf.message.Message):
             b"replace",
             "sample",
             b"sample",
+            "sample_by",
+            b"sample_by",
             "set_op",
             b"set_op",
             "show_string",
@@ -382,6 +388,8 @@ class Relation(google.protobuf.message.Message):
             b"replace",
             "sample",
             b"sample",
+            "sample_by",
+            b"sample_by",
             "set_op",
             b"set_op",
             "show_string",
@@ -445,6 +453,7 @@ class Relation(google.protobuf.message.Message):
         "corr",
         "approx_quantile",
         "freq_items",
+        "sample_by",
         "catalog",
         "extension",
         "unknown",
@@ -1910,6 +1919,94 @@ class StatFreqItems(google.protobuf.message.Message):
 
 global___StatFreqItems = StatFreqItems
 
+class StatSampleBy(google.protobuf.message.Message):
+    """Returns a stratified sample without replacement based on the fraction
+    given on each stratum.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    class Fraction(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        STRATUM_FIELD_NUMBER: builtins.int
+        FRACTION_FIELD_NUMBER: builtins.int
+        @property
+        def stratum(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression.Literal:
+            """(Required) The stratum."""
+        fraction: builtins.float
+        """(Required) The fraction value. Must be in [0, 1]."""
+        def __init__(
+            self,
+            *,
+            stratum: pyspark.sql.connect.proto.expressions_pb2.Expression.Literal | None = ...,
+            fraction: builtins.float = ...,
+        ) -> None: ...
+        def HasField(
+            self, field_name: typing_extensions.Literal["stratum", b"stratum"]
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal["fraction", b"fraction", "stratum", b"stratum"],
+        ) -> None: ...
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COL_FIELD_NUMBER: builtins.int
+    FRACTIONS_FIELD_NUMBER: builtins.int
+    SEED_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    @property
+    def col(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression:
+        """(Required) The column that defines strata."""
+    @property
+    def fractions(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        global___StatSampleBy.Fraction
+    ]:
+        """(Required) Sampling fraction for each stratum.
+
+        If a stratum is not specified, we treat its fraction as zero.
+        """
+    seed: builtins.int
+    """(Optional) The random seed."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        col: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
+        fractions: collections.abc.Iterable[global___StatSampleBy.Fraction] | None = ...,
+        seed: builtins.int | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_seed", b"_seed", "col", b"col", "input", b"input", "seed", b"seed"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_seed",
+            b"_seed",
+            "col",
+            b"col",
+            "fractions",
+            b"fractions",
+            "input",
+            b"input",
+            "seed",
+            b"seed",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
+    ) -> typing_extensions.Literal["seed"] | None: ...
+
+global___StatSampleBy = StatSampleBy
+
 class NAFill(google.protobuf.message.Message):
     """Replaces null values.
    It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0b615d2e32a..6a65e412dfd 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1194,6 +1194,34 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         ):
             self.connect.read.table(self.tbl_name2).stat.freqItems("col1")
 
+    def test_stat_sample_by(self):
+        # SPARK-41069: Test stat.sample_by
+
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key"))
+        sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key"))
+
+        self.assert_eq(
+            cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
+            .groupBy("key")
+            .agg(CF.count(CF.lit(1)))
+            .orderBy("key")
+            .toPandas(),
+            sdf.sampleBy(sdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
+            .groupBy("key")
+            .agg(SF.count(SF.lit(1)))
+            .orderBy("key")
+            .toPandas(),
+        )
+
+        with self.assertRaisesRegex(TypeError, "key must be float, int, or string"):
+            cdf.stat.sampleBy(cdf.key, fractions={0: 0.1, None: 0.2}, seed=0)
+
+        with self.assertRaises(SparkConnectException):
+            cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 1.2}, seed=0).show()
+
     def test_repr(self):
         # SPARK-41213: Test the __repr__ method
         query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
