This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new f9398aa Add Sample functionality in DataFrame
f9398aa is described below
commit f9398aa124b5effaed70f9b6b804cb82b94d40fb
Author: vatsal <[email protected]>
AuthorDate: Thu Jan 2 20:30:37 2025 +0100
Add Sample functionality in DataFrame
Signed-off-by: Vatsal <vatsal.v.anand@gmail.com>
### What changes were proposed in this pull request?
This PR adds support for the `sample` DF function.
### Why are the changes needed?
Compatibility
### Does this PR introduce _any_ user-facing change?
new DF functions
### How was this patch tested?
Added e2e tests
Closes #84 from imvtsl/issue.
Lead-authored-by: vatsal <[email protected]>
Co-authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/dataframe_test.go | 144 +++++++++++++++++++++++++++
spark/sql/dataframe.go | 53 +++++++++-
spark/sql/dataframestatfunctions.go | 24 +++++
3 files changed, 220 insertions(+), 1 deletion(-)
diff --git a/internal/tests/integration/dataframe_test.go
b/internal/tests/integration/dataframe_test.go
index 6220b8e..720cc98 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -805,6 +805,150 @@ func TestDataFrame_WithOption(t *testing.T) {
assert.Equal(t, int64(10), c)
}
+func TestDataFrame_Sample(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(100)")
+ assert.NoError(t, err)
+ testCases := []struct {
+ name string
+ fraction float64
+ }{
+ {
+ name: "Default behavior",
+ fraction: 0.1,
+ },
+ {
+ name: "Large fraction",
+ fraction: 0.9,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ sampledDF, err := df.Sample(ctx, tc.fraction)
+ assert.NoError(t, err)
+ count, err := sampledDF.Count(ctx)
+ assert.NoError(t, err)
+ expectedSize := int(100 * tc.fraction)
+ assert.InDelta(t, expectedSize, count,
float64(expectedSize), 10)
+ rows, err := sampledDF.Collect(ctx)
+ assert.NoError(t, err)
+ // If sampling without replacement, check for duplicates
+ seen := make(map[int64]bool)
+ for _, row := range rows {
+ value := row.At(0).(int64)
+ if seen[value] {
+ t.Fatal("Found duplicate value when
sampling without replacement")
+ }
+ seen[value] = true
+ }
+ })
+ }
+}
+
+func TestDataFrame_SampleWithReplacement(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(100)")
+ assert.NoError(t, err)
+ testCases := []struct {
+ name string
+ withReplacement bool
+ fraction float64
+ }{
+ {
+ name: "With replacement",
+ withReplacement: true,
+ fraction: 0.1,
+ },
+ {
+ name: "Without replacement",
+ withReplacement: false,
+ fraction: 0.1,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ sampledDF, err := df.SampleWithReplacement(ctx,
tc.withReplacement, tc.fraction)
+ assert.NoError(t, err)
+ count, err := sampledDF.Count(ctx)
+ assert.NoError(t, err)
+ expectedSize := int(100 * tc.fraction)
+ assert.InDelta(t, expectedSize, count,
float64(expectedSize), 10)
+ rows, err := sampledDF.Collect(ctx)
+ assert.NoError(t, err)
+ // If sampling without replacement, check for duplicates
+ if tc.withReplacement == false {
+ seen := make(map[int64]bool)
+ for _, row := range rows {
+ value := row.At(0).(int64)
+ if seen[value] {
+ t.Fatal("Found duplicate value
when sampling without replacement")
+ }
+ seen[value] = true
+ }
+ }
+ })
+ }
+}
+
+func TestDataFrame_SampleSeed(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(100)")
+ assert.NoError(t, err)
+ fraction := 0.1
+ seed := int64(17)
+ sampledDF, err := df.SampleWithSeed(ctx, fraction, seed)
+ assert.NoError(t, err)
+ count, err := sampledDF.Count(ctx)
+ assert.NoError(t, err)
+ expectedSize := int(100 * fraction)
+ assert.InDelta(t, expectedSize, count, float64(expectedSize), 10)
+ rows, err := sampledDF.Collect(ctx)
+ assert.NoError(t, err)
+ // If sampling without replacement, check for duplicates
+ seen := make(map[int64]bool)
+ for _, row := range rows {
+ value := row.At(0).(int64)
+ if seen[value] {
+ t.Fatal("Found duplicate value when sampling without
replacement")
+ }
+ seen[value] = true
+ }
+ // same seed should return same output
+ sampledDFRepeat, err := df.SampleWithSeed(ctx, fraction, seed)
+ assert.NoError(t, err)
+ count2, err := sampledDFRepeat.Count(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, count, count2)
+ rows2, err := sampledDFRepeat.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, rows, rows2)
+}
+
+// TestDataFrame_SampleWithReplacementSeed verifies SampleWithReplacementAndSeed:
+// repeating the call with the same seed reproduces exactly the same rows.
+func TestDataFrame_SampleWithReplacementSeed(t *testing.T) {
+	ctx, spark := connect()
+	df, err := spark.Sql(ctx, "select * from range(100)")
+	assert.NoError(t, err)
+	fraction := 0.1
+	seed := int64(17)
+	sampledDF, err := df.SampleWithReplacementAndSeed(ctx, true, fraction, seed)
+	assert.NoError(t, err)
+	count, err := sampledDF.Count(ctx)
+	assert.NoError(t, err)
+	// Sampling is probabilistic; only check the count is roughly right.
+	// The stray trailing 10 (interpreted by testify as msgAndArgs, not a
+	// delta) has been removed from the InDelta call.
+	expectedSize := int(100 * fraction)
+	assert.InDelta(t, expectedSize, count, float64(expectedSize))
+	rows, err := sampledDF.Collect(ctx)
+	assert.NoError(t, err)
+	// The same seed must reproduce exactly the same sample, even with
+	// replacement (duplicates are permitted here, so no uniqueness check).
+	sampledDFRepeat, err := df.SampleWithReplacementAndSeed(ctx, true, fraction, seed)
+	assert.NoError(t, err)
+	count2, err := sampledDFRepeat.Count(ctx)
+	assert.NoError(t, err)
+	assert.Equal(t, count, count2)
+	rows2, err := sampledDFRepeat.Collect(ctx)
+	assert.NoError(t, err)
+	assert.Equal(t, rows, rows2)
+}
+
func TestDataFrame_Unpivot(t *testing.T) {
ctx, spark := connect()
data := [][]any{{1, 11, 1.1}, {2, 12, 1.2}}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 1d2b20b..ce67b0e 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -177,6 +177,14 @@ type DataFrame interface {
Rollup(ctx context.Context, cols ...column.Convertible) *GroupedData
// SameSemantics returns true if the other DataFrame has the same
semantics.
SameSemantics(ctx context.Context, other DataFrame) (bool, error)
+ // Sample samples a data frame without replacement and random seed.
+ Sample(ctx context.Context, fraction float64) (DataFrame, error)
+ // SampleWithReplacement samples a data frame with random seed and
with/without replacement.
+ SampleWithReplacement(ctx context.Context, withReplacement bool,
fraction float64) (DataFrame, error)
+ // SampleWithSeed samples a data frame without replacement and given
seed.
+ SampleWithSeed(ctx context.Context, fraction float64, seed int64)
(DataFrame, error)
+ // SampleWithReplacementAndSeed samples a data frame with/without
replacement and given seed.
+ SampleWithReplacementAndSeed(ctx context.Context, withReplacement bool,
fraction float64, seed int64) (DataFrame, error)
// Show uses WriteResult to write the data frames to the console output.
Show(ctx context.Context, numRows int, truncate bool) error
// Schema returns the schema for the current data frame.
@@ -1401,6 +1409,50 @@ func (df *dataFrameImpl) Summary(ctx context.Context,
statistics ...string) Data
return NewDataFrame(df.session, rel)
}
+func (df *dataFrameImpl) Sample(ctx context.Context, fraction float64)
(DataFrame, error) {
+ return df.sample(ctx, nil, fraction, nil)
+}
+
+func (df *dataFrameImpl) SampleWithReplacement(ctx context.Context,
withReplacement bool, fraction float64) (DataFrame, error) {
+ return df.sample(ctx, &withReplacement, fraction, nil)
+}
+
+func (df *dataFrameImpl) SampleWithSeed(ctx context.Context, fraction float64,
seed int64) (DataFrame, error) {
+ return df.sample(ctx, nil, fraction, &seed)
+}
+
+func (df *dataFrameImpl) SampleWithReplacementAndSeed(ctx context.Context,
withReplacement bool, fraction float64, seed int64) (DataFrame, error) {
+ return df.sample(ctx, &withReplacement, fraction, &seed)
+}
+
+func (df *dataFrameImpl) sample(ctx context.Context, withReplacement *bool,
fraction float64, seed *int64) (DataFrame, error) {
+ if seed == nil {
+ defaultSeed := rand.Int64()
+ seed = &defaultSeed
+ }
+
+ if withReplacement == nil {
+ defaultWithReplacement := false
+ withReplacement = &defaultWithReplacement
+ }
+
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_Sample{
+ Sample: &proto.Sample{
+ Input: df.relation,
+ LowerBound: 0,
+ UpperBound: fraction,
+ WithReplacement: withReplacement,
+ Seed: seed,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel), nil
+}
+
func (df *dataFrameImpl) Replace(ctx context.Context,
toReplace []types.PrimitiveTypeLiteral, values
[]types.PrimitiveTypeLiteral, cols ...string,
) (DataFrame, error) {
@@ -1441,7 +1493,6 @@ func (df *dataFrameImpl) Replace(ctx context.Context,
Common: &proto.RelationCommon{
PlanId: newPlanId(),
},
-
RelType: &proto.Relation_Replace{
Replace: &proto.NAReplace{
Input: df.relation,
diff --git a/spark/sql/dataframestatfunctions.go
b/spark/sql/dataframestatfunctions.go
index 7c05865..d68bfad 100644
--- a/spark/sql/dataframestatfunctions.go
+++ b/spark/sql/dataframestatfunctions.go
@@ -25,12 +25,36 @@ type DataFrameStatFunctions interface {
CrossTab(ctx context.Context, col1, col2 string) DataFrame
FreqItems(ctx context.Context, cols ...string) DataFrame
FreqItemsWithSupport(ctx context.Context, support float64, cols
...string) DataFrame
+ Sample(ctx context.Context, fraction float64) (DataFrame, error)
+ SampleWithReplacement(ctx context.Context, withReplacement bool,
fraction float64) (DataFrame, error)
+ SampleWithSeed(ctx context.Context, fraction float64, seed int64)
(DataFrame, error)
+ SampleWithReplacementAndSeed(ctx context.Context, withReplacement bool,
fraction float64, seed int64) (DataFrame, error)
}
type dataFrameStatFunctionsImpl struct {
df DataFrame
}
+// Sample forwards to the wrapped DataFrame's Sample.
+func (d *dataFrameStatFunctionsImpl) Sample(ctx context.Context, fraction float64) (DataFrame, error) {
+	return d.df.Sample(ctx, fraction)
+}
+
+// SampleWithReplacement forwards to the wrapped DataFrame's
+// SampleWithReplacement.
+func (d *dataFrameStatFunctionsImpl) SampleWithReplacement(ctx context.Context, withReplacement bool, fraction float64) (DataFrame, error) {
+	return d.df.SampleWithReplacement(ctx, withReplacement, fraction)
+}
+
+// SampleWithSeed forwards to the wrapped DataFrame's SampleWithSeed.
+func (d *dataFrameStatFunctionsImpl) SampleWithSeed(ctx context.Context, fraction float64, seed int64) (DataFrame, error) {
+	return d.df.SampleWithSeed(ctx, fraction, seed)
+}
+
+// SampleWithReplacementAndSeed forwards to the wrapped DataFrame's
+// SampleWithReplacementAndSeed.
+func (d *dataFrameStatFunctionsImpl) SampleWithReplacementAndSeed(ctx context.Context, withReplacement bool, fraction float64, seed int64) (DataFrame, error) {
+	return d.df.SampleWithReplacementAndSeed(ctx, withReplacement, fraction, seed)
+}
+
+
func (d *dataFrameStatFunctionsImpl) ApproxQuantile(ctx context.Context,
probabilities []float64,
relativeError float64, cols ...string,
) ([][]float64, error) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]