This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new cc18a0e  #58 Add support for `df.Stat()` and `df.ApproxQuantile()`
cc18a0e is described below

commit cc18a0e3c16a1abecc80db03bfb3ce20b875b92c
Author: Martin Grund <[email protected]>
AuthorDate: Thu Jan 2 11:43:01 2025 +0100

    #58 Add support for `df.Stat()` and `df.ApproxQuantile()`
    
    ### What changes were proposed in this pull request?
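    Adds `df.ApproxQuantile()` and a `df.Stat()` accessor that groups the statistical helpers (`ApproxQuantile`, `Cov`, `Corr`, `CorrWithMethod`, `CrossTab`, `FreqItems`, `FreqItemsWithSupport`) behind a single `DataFrameStatFunctions` interface.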
    
    ### Why are the changes needed?
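    API compatibility with the other Spark DataFrame clients, which expose these statistical helpers through `DataFrame.stat`.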
    
    ### Does this PR introduce _any_ user-facing change?
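    Yes. New public API: `DataFrame.ApproxQuantile()`, `DataFrame.Stat()`, and the `DataFrameStatFunctions` interface.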
    
    ### How was this patch tested?
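    Added integration tests in `internal/tests/integration/dataframe_test.go`:
    - a new `TestDataFrame_ApproxQuantile`;
    - extra checks in the existing `Corr`, `Cov`, `CrossTab`, and `FreqItems` tests confirming that the `df.Stat()` variants return the same results as the direct methods.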
    
    Closes #101 from grundprinzip/df_stat.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go | 61 +++++++++++++++++++++++++++
 spark/sql/dataframe.go                       | 50 ++++++++++++++++++++++
 spark/sql/dataframestatfunctions.go          | 62 ++++++++++++++++++++++++++++
 3 files changed, 173 insertions(+)

diff --git a/internal/tests/integration/dataframe_test.go 
b/internal/tests/integration/dataframe_test.go
index 0afe3f2..1290dc9 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -188,6 +188,10 @@ func TestDataFrame_Corr(t *testing.T) {
        res, err := df.Corr(ctx, "c1", "c2")
        assert.NoError(t, err)
        assert.Equal(t, -0.3592106040535498, res)
+
+       res2, err := df.Stat().Corr(ctx, "c1", "c2")
+       assert.NoError(t, err)
+       assert.Equal(t, res, res2)
 }
 
 func TestDataFrame_Cov(t *testing.T) {
@@ -205,6 +209,10 @@ func TestDataFrame_Cov(t *testing.T) {
        res, err := df.Cov(ctx, "c1", "c2")
        assert.NoError(t, err)
        assert.Equal(t, -18.0, res)
+
+       res2, err := df.Stat().Cov(ctx, "c1", "c2")
+       assert.NoError(t, err)
+       assert.Equal(t, res, res2)
 }
 
 func TestDataFrame_WithColumn(t *testing.T) {
@@ -668,6 +676,16 @@ func TestDataFrame_CrossTab(t *testing.T) {
        assert.Equal(t, int64(0), res[0].At(1))
        assert.Equal(t, int64(2), res[0].At(2))
        assert.Equal(t, int64(0), res[0].At(3))
+
+       df, err = spark.CreateDataFrame(ctx, data, schema)
+       assert.NoError(t, err)
+       df = df.Stat().CrossTab(ctx, "c1", "c2")
+       df, err = df.Sort(ctx, column.OfDF(df, "c1_c2").Asc())
+       assert.NoError(t, err)
+       res2, err := df.Collect(ctx)
+       assert.NoError(t, err)
+
+       assert.Equal(t, res, res2)
 }
 
 func TestDataFrame_SameSemantics(t *testing.T) {
@@ -696,6 +714,10 @@ func TestDataFrame_FreqItems(t *testing.T) {
        res, err := df.FreqItems(ctx, "id").Collect(ctx)
        assert.NoErrorf(t, err, "%+v", err)
        assert.Len(t, res, 1)
+
+       res2, err := df.Stat().FreqItems(ctx, "id").Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, res, res2)
 }
 
 func TestDataFrame_Config_GetAll(t *testing.T) {
@@ -926,6 +948,45 @@ func TestDataFrame_FillNa(t *testing.T) {
        assert.Equal(t, []any{int64(1), int64(10), int64(1)}, res[1].Values())
 }
 
+func TestDataFrame_ApproxQuantile(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select id, 1 as id2 from range(100)")
+       assert.NoError(t, err)
+       res, err := df.ApproxQuantile(ctx, []float64{float64(0.5)}, 
float64(0.1), "id")
+       assert.NoError(t, err)
+       assert.Len(t, res, 1)
+
+       data := [][]any{
+               {"bob", "Developer", 125000, 1},
+               {"mark", "Developer", 108000, 2},
+               {"carl", "Tester", 70000, 2},
+               {"peter", "Developer", 185000, 2},
+               {"jon", "Tester", 65000, 1},
+               {"roman", "Tester", 82000, 2},
+               {"simon", "Developer", 98000, 1},
+               {"eric", "Developer", 144000, 2},
+               {"carlos", "Tester", 75000, 1},
+               {"henry", "Developer", 110000, 1},
+       }
+       schema := types.StructOf(
+               types.NewStructField("Name", types.STRING),
+               types.NewStructField("Role", types.STRING),
+               types.NewStructField("Salary", types.LONG),
+               types.NewStructField("Performance", types.LONG),
+       )
+
+       df, err = spark.CreateDataFrame(ctx, data, schema)
+       assert.NoError(t, err)
+       med, err := df.ApproxQuantile(ctx, []float64{float64(0.5)}, 
float64(0.25), "Salary")
+
+       assert.NoError(t, err)
+       assert.Len(t, med, 1)
+       assert.GreaterOrEqual(t, med[0][0], 75000.0)
+
+       _, err = df.Stat().ApproxQuantile(ctx, []float64{0.5}, 0.25, "Salary")
+       assert.NoError(t, err)
+}
+
 func TestDataFrame_DFNaFunctions(t *testing.T) {
        ctx, spark := connect()
        data := [][]any{
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 05a403e..2612387 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -45,6 +45,7 @@ type DataFrame interface {
        AggWithMap(ctx context.Context, exprs map[string]string) (DataFrame, 
error)
        // Alias creates a new DataFrame with the specified subquery alias
        Alias(ctx context.Context, alias string) DataFrame
+       ApproxQuantile(ctx context.Context, probabilities []float64, 
relativeError float64, cols ...string) ([][]float64, error)
        // Cache persists the DataFrame with the default storage level.
        Cache(ctx context.Context) error
        // Coalesce returns a new DataFrame that has exactly numPartitions 
partitions.DataFrame
@@ -187,6 +188,7 @@ type DataFrame interface {
        SemanticHash(ctx context.Context) (int32, error)
        // Sort returns a new DataFrame sorted by the specified columns.
        Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, 
error)
+       Stat() DataFrameStatFunctions
        // Subtract subtracts the other DataFrame from the current DataFrame. 
And only returns
        // distinct rows.
        Subtract(ctx context.Context, other DataFrame) DataFrame
@@ -1544,6 +1546,10 @@ func (df *dataFrameImpl) FillNaWithValues(ctx 
context.Context,
        return makeDataframeWithFillNaRelation(df, valueLiterals, columns), nil
 }
 
+func (df *dataFrameImpl) Stat() DataFrameStatFunctions {
+       return &dataFrameStatFunctionsImpl{df: df}
+}
+
 func (df *dataFrameImpl) Agg(ctx context.Context, cols ...column.Convertible) 
(DataFrame, error) {
        return df.GroupBy().Agg(ctx, cols...)
 }
@@ -1560,6 +1566,50 @@ func (df *dataFrameImpl) AggWithMap(ctx context.Context, 
exprs map[string]string
        return df.Agg(ctx, funs...)
 }
 
+func (df *dataFrameImpl) ApproxQuantile(ctx context.Context, probabilities 
[]float64,
+       relativeError float64, cols ...string,
+) ([][]float64, error) {
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_ApproxQuantile{
+                       ApproxQuantile: &proto.StatApproxQuantile{
+                               Input:         df.relation,
+                               Probabilities: probabilities,
+                               RelativeError: relativeError,
+                               Cols:          cols,
+                       },
+               },
+       }
+       data := NewDataFrame(df.session, rel)
+       rows, err := data.Collect(ctx)
+       if err != nil {
+               return nil, err
+       }
+
+       // The result structure is a bit weird here, essentially it returns 
exactly one row with
+       // the quantiles.
+       // Inside the row is a list of nested arrays that contain the 
quantiles. The first column is the
+       // first nested array, the second column is the second nested array and 
so on.
+
+       nested := rows[0].At(0).([]interface{})
+       result := make([][]float64, len(nested))
+       for i := 0; i < len(nested); i++ {
+               tmp := nested[i].([]interface{})
+               result[i] = make([]float64, len(tmp))
+               for j := 0; j < len(tmp); j++ {
+                       f, ok := tmp[j].(float64)
+                       if !ok {
+                               return nil, sparkerrors.WithType(fmt.Errorf(
+                                       "failed to cast to float64"), 
sparkerrors.ExecutionError)
+                       }
+                       result[i][j] = f
+               }
+       }
+       return result, nil
+}
+
 func (df *dataFrameImpl) DropNa(ctx context.Context, subset ...string) 
(DataFrame, error) {
        rel := &proto.Relation{
                Common: &proto.RelationCommon{
diff --git a/spark/sql/dataframestatfunctions.go 
b/spark/sql/dataframestatfunctions.go
new file mode 100644
index 0000000..7c05865
--- /dev/null
+++ b/spark/sql/dataframestatfunctions.go
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import "context"
+
+type DataFrameStatFunctions interface {
+       ApproxQuantile(ctx context.Context, probabilities []float64, 
relativeError float64, cols ...string) ([][]float64, error)
+       Cov(ctx context.Context, col1, col2 string) (float64, error)
+       Corr(ctx context.Context, col1, col2 string) (float64, error)
+       CorrWithMethod(ctx context.Context, col1, col2 string, method string) 
(float64, error)
+       CrossTab(ctx context.Context, col1, col2 string) DataFrame
+       FreqItems(ctx context.Context, cols ...string) DataFrame
+       FreqItemsWithSupport(ctx context.Context, support float64, cols 
...string) DataFrame
+}
+
+type dataFrameStatFunctionsImpl struct {
+       df DataFrame
+}
+
+func (d *dataFrameStatFunctionsImpl) ApproxQuantile(ctx context.Context, 
probabilities []float64,
+       relativeError float64, cols ...string,
+) ([][]float64, error) {
+       return d.df.ApproxQuantile(ctx, probabilities, relativeError, cols...)
+}
+
+func (d *dataFrameStatFunctionsImpl) Cov(ctx context.Context, col1, col2 
string) (float64, error) {
+       return d.df.Cov(ctx, col1, col2)
+}
+
+func (d *dataFrameStatFunctionsImpl) Corr(ctx context.Context, col1, col2 
string) (float64, error) {
+       return d.df.Corr(ctx, col1, col2)
+}
+
+func (d *dataFrameStatFunctionsImpl) CorrWithMethod(ctx context.Context, col1, 
col2 string, method string) (float64, error) {
+       return d.df.CorrWithMethod(ctx, col1, col2, method)
+}
+
+func (d *dataFrameStatFunctionsImpl) CrossTab(ctx context.Context, col1, col2 
string) DataFrame {
+       return d.df.CrossTab(ctx, col1, col2)
+}
+
+func (d *dataFrameStatFunctionsImpl) FreqItems(ctx context.Context, cols 
...string) DataFrame {
+       return d.df.FreqItems(ctx, cols...)
+}
+
+func (d *dataFrameStatFunctionsImpl) FreqItemsWithSupport(ctx context.Context, 
support float64, cols ...string) DataFrame {
+       return d.df.FreqItemsWithSupport(ctx, support, cols...)
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to