This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new 7cc2cbc #58 Add support for `df.agg()`
7cc2cbc is described below
commit 7cc2cbcb9293f77325fb556f2663c8965479cd10
Author: Martin Grund <[email protected]>
AuthorDate: Thu Jan 2 11:17:21 2025 +0100
#58 Add support for `df.agg()`
### What changes were proposed in this pull request?
Add support for `df.Agg()` and `df.AggWithMap()`.
### Why are the changes needed?
API compatibility with `DataFrame.agg()` as offered by other Spark clients such as PySpark.
### Does this PR introduce _any_ user-facing change?
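New functions `DataFrame.Agg()` and `DataFrame.AggWithMap()`. In addition, `GroupedData.Agg()` now takes `...column.Convertible` instead of `...column.Column`, so callers that spread a `[]column.Column` slice into it (`gd.Agg(ctx, cols...)`) must first convert that slice to `[]column.Convertible`.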
### How was this patch tested?
Added the integration test `TestAggregationFunctions_Agg`.
Closes #100 from grundprinzip/df_agg.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/functions_test.go | 19 +++++++++++++++++++
spark/sql/dataframe.go | 18 ++++++++++++++++++
spark/sql/group.go | 4 ++--
3 files changed, 39 insertions(+), 2 deletions(-)
diff --git a/internal/tests/integration/functions_test.go
b/internal/tests/integration/functions_test.go
index b286c02..878ad07 100644
--- a/internal/tests/integration/functions_test.go
+++ b/internal/tests/integration/functions_test.go
@@ -41,6 +41,25 @@ func TestIntegration_BuiltinFunctions(t *testing.T) {
assert.Equal(t, 10, len(res))
}
+func TestAggregationFunctions_Agg(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select id, 1, 2, 3 from range(100)") // 100 rows, id = 0..99
+ assert.NoError(t, err)
+
+ res, err := df.Agg(ctx, functions.Count(functions.Col("id"))) // global aggregate: Agg groups by no columns
+ assert.NoError(t, err) // NOTE(review): assert does not stop the test; if Agg fails, res may be nil and res.Count panics - consider require.NoError
+ cnt, err := res.Count(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(1), cnt) // a global aggregate yields exactly one row
+
+ res, err = df.AggWithMap(ctx, map[string]string{"id": "sum"}) // column name -> aggregate function name
+ assert.NoError(t, err)
+ rows, err := res.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, rows, 1)
+ assert.Equal(t, int64(4950), rows[0].At(0)) // sum(0..99) = 99*100/2
+}
+
func TestIntegration_ColumnGetItem(t *testing.T) {
ctx := context.Background()
spark, err :=
sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 3979c34..05a403e 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -41,6 +41,8 @@ type ResultCollector interface {
type DataFrame interface {
// PlanId returns the plan id of the data frame.
PlanId() int64
+ Agg(ctx context.Context, exprs ...column.Convertible) (DataFrame, error) // Agg computes aggregates over the whole DataFrame without grouping columns.
+ AggWithMap(ctx context.Context, exprs map[string]string) (DataFrame,
error) // AggWithMap aggregates each key column with the function named by its value, e.g. {"id": "sum"}.
// Alias creates a new DataFrame with the specified subquery alias
Alias(ctx context.Context, alias string) DataFrame
// Cache persists the DataFrame with the default storage level.
@@ -1542,6 +1544,22 @@ func (df *dataFrameImpl) FillNaWithValues(ctx
context.Context,
return makeDataframeWithFillNaRelation(df, valueLiterals, columns), nil
}
+func (df *dataFrameImpl) Agg(ctx context.Context, cols ...column.Convertible)
(DataFrame, error) { // Agg delegates to GroupedData.Agg with no grouping columns.
+ return df.GroupBy().Agg(ctx, cols...)
+}
+
+func (df *dataFrameImpl) AggWithMap(ctx context.Context, exprs
map[string]string) (DataFrame, error) { // AggWithMap: exprs maps column name -> aggregate function name, e.g. {"id": "sum"}.
+ funs := make([]column.Convertible, 0, len(exprs)) // exactly one aggregate per map entry
+ for k, v := range exprs { // Go map iteration order is random, so the column order of the result is not deterministic.
+ // Convert the column name to a column expression.
+ col := column.OfDF(df, k)
+ // Build an unresolved call of the function named v, applied to col.
+ fun := column.NewUnresolvedFunctionWithColumns(v, col)
+ funs = append(funs, fun)
+ }
+ return df.Agg(ctx, funs...) // an empty exprs map is rejected by GroupedData.Agg with InvalidInputError
+}
+
func (df *dataFrameImpl) DropNa(ctx context.Context, subset ...string)
(DataFrame, error) {
rel := &proto.Relation{
Common: &proto.RelationCommon{
diff --git a/spark/sql/group.go b/spark/sql/group.go
index 975dc50..157e482 100644
--- a/spark/sql/group.go
+++ b/spark/sql/group.go
@@ -37,7 +37,7 @@ type GroupedData struct {
// Agg compute aggregates and returns the result as a DataFrame. The aggegrate
expressions
// are passed as column.Column arguments.
-func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Column)
(DataFrame, error) {
+func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Convertible)
(DataFrame, error) { // NOTE(review): the doc comment above still says column.Column and misspells "aggregate"; update it for the new column.Convertible parameter.
if len(exprs) == 0 {
return nil,
sparkerrors.WithString(sparkerrors.InvalidInputError, "exprs should not be
empty")
}
@@ -144,7 +144,7 @@ func (gd *GroupedData) numericAgg(ctx context.Context, name
string, cols ...stri
aggCols = numericCols
}
- finalColumns := make([]column.Column, len(aggCols))
+ finalColumns := make([]column.Convertible, len(aggCols)) // presumably spread into gd.Agg, which now takes ...column.Convertible - confirm
for i, col := range aggCols {
finalColumns[i] =
column.NewColumn(column.NewUnresolvedFunctionWithColumns(name,
functions.Col(col)))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]