This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 7cc2cbc  #58 Add support for `df.agg()`
7cc2cbc is described below

commit 7cc2cbcb9293f77325fb556f2663c8965479cd10
Author: Martin Grund <[email protected]>
AuthorDate: Thu Jan 2 11:17:21 2025 +0100

    #58 Add support for `df.agg()`
    
    ### What changes were proposed in this pull request?
    Add support for `df.Agg()` and `df.AggWithMap()`.
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
    New functions
    
    ### How was this patch tested?
    Added test
    
    Closes #100 from grundprinzip/df_agg.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/functions_test.go | 19 +++++++++++++++++++
 spark/sql/dataframe.go                       | 18 ++++++++++++++++++
 spark/sql/group.go                           |  4 ++--
 3 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/internal/tests/integration/functions_test.go 
b/internal/tests/integration/functions_test.go
index b286c02..878ad07 100644
--- a/internal/tests/integration/functions_test.go
+++ b/internal/tests/integration/functions_test.go
@@ -41,6 +41,25 @@ func TestIntegration_BuiltinFunctions(t *testing.T) {
        assert.Equal(t, 10, len(res))
 }
 
+func TestAggregationFunctions_Agg(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select id, 1, 2, 3 from range(100)")
+       assert.NoError(t, err)
+
+       res, err := df.Agg(ctx, functions.Count(functions.Col("id")))
+       assert.NoError(t, err)
+       cnt, err := res.Count(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, int64(1), cnt)
+
+       res, err = df.AggWithMap(ctx, map[string]string{"id": "sum"})
+       assert.NoError(t, err)
+       rows, err := res.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Len(t, rows, 1)
+       assert.Equal(t, int64(4950), rows[0].At(0))
+}
+
 func TestIntegration_ColumnGetItem(t *testing.T) {
        ctx := context.Background()
        spark, err := 
sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 3979c34..05a403e 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -41,6 +41,8 @@ type ResultCollector interface {
 type DataFrame interface {
        // PlanId returns the plan id of the data frame.
        PlanId() int64
+       Agg(ctx context.Context, exprs ...column.Convertible) (DataFrame, error)
+       AggWithMap(ctx context.Context, exprs map[string]string) (DataFrame, 
error)
        // Alias creates a new DataFrame with the specified subquery alias
        Alias(ctx context.Context, alias string) DataFrame
        // Cache persists the DataFrame with the default storage level.
@@ -1542,6 +1544,22 @@ func (df *dataFrameImpl) FillNaWithValues(ctx 
context.Context,
        return makeDataframeWithFillNaRelation(df, valueLiterals, columns), nil
 }
 
+func (df *dataFrameImpl) Agg(ctx context.Context, cols ...column.Convertible) 
(DataFrame, error) {
+       return df.GroupBy().Agg(ctx, cols...)
+}
+
+func (df *dataFrameImpl) AggWithMap(ctx context.Context, exprs 
map[string]string) (DataFrame, error) {
+       funs := make([]column.Convertible, 0)
+       for k, v := range exprs {
+               // Convert the column name to a column expression.
+               col := column.OfDF(df, k)
+               // Convert the value string to an unresolved function name.
+               fun := column.NewUnresolvedFunctionWithColumns(v, col)
+               funs = append(funs, fun)
+       }
+       return df.Agg(ctx, funs...)
+}
+
 func (df *dataFrameImpl) DropNa(ctx context.Context, subset ...string) 
(DataFrame, error) {
        rel := &proto.Relation{
                Common: &proto.RelationCommon{
diff --git a/spark/sql/group.go b/spark/sql/group.go
index 975dc50..157e482 100644
--- a/spark/sql/group.go
+++ b/spark/sql/group.go
@@ -37,7 +37,7 @@ type GroupedData struct {
 
 // Agg compute aggregates and returns the result as a DataFrame. The aggregate 
expressions
 // are passed as column.Column arguments.
-func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Column) 
(DataFrame, error) {
+func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Convertible) 
(DataFrame, error) {
        if len(exprs) == 0 {
                return nil, 
sparkerrors.WithString(sparkerrors.InvalidInputError, "exprs should not be 
empty")
        }
@@ -144,7 +144,7 @@ func (gd *GroupedData) numericAgg(ctx context.Context, name 
string, cols ...stri
                aggCols = numericCols
        }
 
-       finalColumns := make([]column.Column, len(aggCols))
+       finalColumns := make([]column.Convertible, len(aggCols))
        for i, col := range aggCols {
                finalColumns[i] = 
column.NewColumn(column.NewUnresolvedFunctionWithColumns(name, 
functions.Col(col)))
        }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to