This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new f9398aa  Add Sample functionality in DataFrame
f9398aa is described below

commit f9398aa124b5effaed70f9b6b804cb82b94d40fb
Author: vatsal <[email protected]>
AuthorDate: Thu Jan 2 20:30:37 2025 +0100

    Add Sample functionality in DataFrame
    
    Signed-off-by: Vatsal <vatsal.v.anandgmail.com>
    
    ### What changes were proposed in this pull request?
    This PR adds support for the `sample` DataFrame function.
    
    ### Why are the changes needed?
    Compatibility with the Spark DataFrame API, which provides `sample`.
    
    ### Does this PR introduce _any_ user-facing change?
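    Yes. It adds new DataFrame functions: `Sample`, `SampleWithReplacement`,
    `SampleWithSeed`, and `SampleWithReplacementAndSeed`.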
    
    ### How was this patch tested?
    Added end-to-end integration tests in `internal/tests/integration/dataframe_test.go`.
    
    Closes #84 from imvtsl/issue.
    
    Lead-authored-by: vatsal <[email protected]>
    Co-authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go | 144 +++++++++++++++++++++++++++
 spark/sql/dataframe.go                       |  53 +++++++++-
 spark/sql/dataframestatfunctions.go          |  24 +++++
 3 files changed, 220 insertions(+), 1 deletion(-)

diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index 6220b8e..720cc98 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -805,6 +805,150 @@ func TestDataFrame_WithOption(t *testing.T) {
        assert.Equal(t, int64(10), c)
 }
 
+func TestDataFrame_Sample(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(100)")
+       assert.NoError(t, err)
+       testCases := []struct {
+               name     string
+               fraction float64
+       }{
+               {
+                       name:     "Default behavior",
+                       fraction: 0.1,
+               },
+               {
+                       name:     "Large fraction",
+                       fraction: 0.9,
+               },
+       }
+       for _, tc := range testCases {
+               t.Run(tc.name, func(t *testing.T) {
+                       sampledDF, err := df.Sample(ctx, tc.fraction)
+                       assert.NoError(t, err)
+                       count, err := sampledDF.Count(ctx)
+                       assert.NoError(t, err)
+                       expectedSize := int(100 * tc.fraction)
+                       assert.InDelta(t, expectedSize, count, 
float64(expectedSize), 10)
+                       rows, err := sampledDF.Collect(ctx)
+                       assert.NoError(t, err)
+                       // If sampling without replacement, check for duplicates
+                       seen := make(map[int64]bool)
+                       for _, row := range rows {
+                               value := row.At(0).(int64)
+                               if seen[value] {
+                                       t.Fatal("Found duplicate value when 
sampling without replacement")
+                               }
+                               seen[value] = true
+                       }
+               })
+       }
+}
+
+func TestDataFrame_SampleWithReplacement(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(100)")
+       assert.NoError(t, err)
+       testCases := []struct {
+               name            string
+               withReplacement bool
+               fraction        float64
+       }{
+               {
+                       name:            "With replacement",
+                       withReplacement: true,
+                       fraction:        0.1,
+               },
+               {
+                       name:            "Without replacement",
+                       withReplacement: false,
+                       fraction:        0.1,
+               },
+       }
+       for _, tc := range testCases {
+               t.Run(tc.name, func(t *testing.T) {
+                       sampledDF, err := df.SampleWithReplacement(ctx, 
tc.withReplacement, tc.fraction)
+                       assert.NoError(t, err)
+                       count, err := sampledDF.Count(ctx)
+                       assert.NoError(t, err)
+                       expectedSize := int(100 * tc.fraction)
+                       assert.InDelta(t, expectedSize, count, 
float64(expectedSize), 10)
+                       rows, err := sampledDF.Collect(ctx)
+                       assert.NoError(t, err)
+                       // If sampling without replacement, check for duplicates
+                       if tc.withReplacement == false {
+                               seen := make(map[int64]bool)
+                               for _, row := range rows {
+                                       value := row.At(0).(int64)
+                                       if seen[value] {
+                                               t.Fatal("Found duplicate value 
when sampling without replacement")
+                                       }
+                                       seen[value] = true
+                               }
+                       }
+               })
+       }
+}
+
+func TestDataFrame_SampleSeed(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(100)")
+       assert.NoError(t, err)
+       fraction := 0.1
+       seed := int64(17)
+       sampledDF, err := df.SampleWithSeed(ctx, fraction, seed)
+       assert.NoError(t, err)
+       count, err := sampledDF.Count(ctx)
+       assert.NoError(t, err)
+       expectedSize := int(100 * fraction)
+       assert.InDelta(t, expectedSize, count, float64(expectedSize), 10)
+       rows, err := sampledDF.Collect(ctx)
+       assert.NoError(t, err)
+       // If sampling without replacement, check for duplicates
+       seen := make(map[int64]bool)
+       for _, row := range rows {
+               value := row.At(0).(int64)
+               if seen[value] {
+                       t.Fatal("Found duplicate value when sampling without 
replacement")
+               }
+               seen[value] = true
+       }
+       // same seed should return same output
+       sampledDFRepeat, err := df.SampleWithSeed(ctx, fraction, seed)
+       assert.NoError(t, err)
+       count2, err := sampledDFRepeat.Count(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, count, count2)
+       rows2, err := sampledDFRepeat.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, rows, rows2)
+}
+
+func TestDataFrame_SampleWithReplacementSeed(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(100)")
+       assert.NoError(t, err)
+       fraction := 0.1
+       seed := int64(17)
+       sampledDF, err := df.SampleWithReplacementAndSeed(ctx, true, fraction, 
seed)
+       assert.NoError(t, err)
+       count, err := sampledDF.Count(ctx)
+       assert.NoError(t, err)
+       expectedSize := int(100 * fraction)
+       assert.InDelta(t, expectedSize, count, float64(expectedSize), 10)
+       rows, err := sampledDF.Collect(ctx)
+       assert.NoError(t, err)
+       // same seed should return same output
+       sampledDFRepeat, err := df.SampleWithReplacementAndSeed(ctx, true, 
fraction, seed)
+       assert.NoError(t, err)
+       count2, err := sampledDFRepeat.Count(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, count, count2)
+       rows2, err := sampledDFRepeat.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, rows, rows2)
+}
+
 func TestDataFrame_Unpivot(t *testing.T) {
        ctx, spark := connect()
        data := [][]any{{1, 11, 1.1}, {2, 12, 1.2}}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 1d2b20b..ce67b0e 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -177,6 +177,14 @@ type DataFrame interface {
        Rollup(ctx context.Context, cols ...column.Convertible) *GroupedData
        // SameSemantics returns true if the other DataFrame has the same 
semantics.
        SameSemantics(ctx context.Context, other DataFrame) (bool, error)
+       // Sample samples a data frame without replacement, using a random seed.
+       Sample(ctx context.Context, fraction float64) (DataFrame, error)
+       // SampleWithReplacement samples a data frame with or without replacement, using a random seed.
+       SampleWithReplacement(ctx context.Context, withReplacement bool, 
fraction float64) (DataFrame, error)
+       // SampleWithSeed samples a data frame without replacement, using the given seed.
+       SampleWithSeed(ctx context.Context, fraction float64, seed int64) 
(DataFrame, error)
+       // SampleWithReplacementAndSeed samples a data frame with or without replacement, using the given seed.
+       SampleWithReplacementAndSeed(ctx context.Context, withReplacement bool, 
fraction float64, seed int64) (DataFrame, error)
        // Show uses WriteResult to write the data frames to the console output.
        Show(ctx context.Context, numRows int, truncate bool) error
        // Schema returns the schema for the current data frame.
@@ -1401,6 +1409,50 @@ func (df *dataFrameImpl) Summary(ctx context.Context, statistics ...string) Data
        return NewDataFrame(df.session, rel)
 }
 
+func (df *dataFrameImpl) Sample(ctx context.Context, fraction float64) 
(DataFrame, error) {
+       return df.sample(ctx, nil, fraction, nil)
+}
+
+func (df *dataFrameImpl) SampleWithReplacement(ctx context.Context, 
withReplacement bool, fraction float64) (DataFrame, error) {
+       return df.sample(ctx, &withReplacement, fraction, nil)
+}
+
+func (df *dataFrameImpl) SampleWithSeed(ctx context.Context, fraction float64, 
seed int64) (DataFrame, error) {
+       return df.sample(ctx, nil, fraction, &seed)
+}
+
+func (df *dataFrameImpl) SampleWithReplacementAndSeed(ctx context.Context, 
withReplacement bool, fraction float64, seed int64) (DataFrame, error) {
+       return df.sample(ctx, &withReplacement, fraction, &seed)
+}
+
+func (df *dataFrameImpl) sample(ctx context.Context, withReplacement *bool, 
fraction float64, seed *int64) (DataFrame, error) {
+       if seed == nil {
+               defaultSeed := rand.Int64()
+               seed = &defaultSeed
+       }
+
+       if withReplacement == nil {
+               defaultWithReplacement := false
+               withReplacement = &defaultWithReplacement
+       }
+
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Sample{
+                       Sample: &proto.Sample{
+                               Input:           df.relation,
+                               LowerBound:      0,
+                               UpperBound:      fraction,
+                               WithReplacement: withReplacement,
+                               Seed:            seed,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel), nil
+}
+
 func (df *dataFrameImpl) Replace(ctx context.Context,
        toReplace []types.PrimitiveTypeLiteral, values 
[]types.PrimitiveTypeLiteral, cols ...string,
 ) (DataFrame, error) {
@@ -1441,7 +1493,6 @@ func (df *dataFrameImpl) Replace(ctx context.Context,
                Common: &proto.RelationCommon{
                        PlanId: newPlanId(),
                },
-
                RelType: &proto.Relation_Replace{
                        Replace: &proto.NAReplace{
                                Input:        df.relation,
diff --git a/spark/sql/dataframestatfunctions.go b/spark/sql/dataframestatfunctions.go
index 7c05865..d68bfad 100644
--- a/spark/sql/dataframestatfunctions.go
+++ b/spark/sql/dataframestatfunctions.go
@@ -25,12 +25,36 @@ type DataFrameStatFunctions interface {
        CrossTab(ctx context.Context, col1, col2 string) DataFrame
        FreqItems(ctx context.Context, cols ...string) DataFrame
        FreqItemsWithSupport(ctx context.Context, support float64, cols 
...string) DataFrame
+       Sample(ctx context.Context, fraction float64) (DataFrame, error)
+       SampleWithReplacement(ctx context.Context, withReplacement bool, 
fraction float64) (DataFrame, error)
+       SampleWithSeed(ctx context.Context, fraction float64, seed int64) 
(DataFrame, error)
+       SampleWithReplacementAndSeed(ctx context.Context, withReplacement bool, 
fraction float64, seed int64) (DataFrame, error)
 }
 
 type dataFrameStatFunctionsImpl struct {
        df DataFrame
 }
 
+func (d *dataFrameStatFunctionsImpl) Sample(ctx context.Context, fraction 
float64) (DataFrame, error) {
+       return d.df.Sample(ctx, fraction)
+}
+
+func (d *dataFrameStatFunctionsImpl) SampleWithReplacement(ctx context.Context,
+       withReplacement bool, fraction float64,
+) (DataFrame, error) {
+       return d.df.SampleWithReplacement(ctx, withReplacement, fraction)
+}
+
+func (d *dataFrameStatFunctionsImpl) SampleWithSeed(ctx context.Context, 
fraction float64, seed int64) (DataFrame, error) {
+       return d.df.SampleWithSeed(ctx, fraction, seed)
+}
+
+func (d *dataFrameStatFunctionsImpl) SampleWithReplacementAndSeed(ctx 
context.Context,
+       withReplacement bool, fraction float64, seed int64,
+) (DataFrame, error) {
+       return d.df.SampleWithReplacementAndSeed(ctx, withReplacement, 
fraction, seed)
+}
+
 func (d *dataFrameStatFunctionsImpl) ApproxQuantile(ctx context.Context, 
probabilities []float64,
        relativeError float64, cols ...string,
 ) ([][]float64, error) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to