This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new cc18a0e #58 Add support for `df.Stat()` and `df.ApproxQuantile()`
cc18a0e is described below
commit cc18a0e3c16a1abecc80db03bfb3ce20b875b92c
Author: Martin Grund <[email protected]>
AuthorDate: Thu Jan 2 11:43:01 2025 +0100
#58 Add support for `df.Stat()` and `df.ApproxQuantile()`
### What changes were proposed in this pull request?
Support for `df.ApproxQuantile()` and the aggregated helper `df.Stat().*`
### Why are the changes needed?
Compatibility
### Does this PR introduce _any_ user-facing change?
New functions
### How was this patch tested?
Added tests.
Closes #101 from grundprinzip/df_stat.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/dataframe_test.go | 61 +++++++++++++++++++++++++++
spark/sql/dataframe.go | 50 ++++++++++++++++++++++
spark/sql/dataframestatfunctions.go | 62 ++++++++++++++++++++++++++++
3 files changed, 173 insertions(+)
diff --git a/internal/tests/integration/dataframe_test.go
b/internal/tests/integration/dataframe_test.go
index 0afe3f2..1290dc9 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -188,6 +188,10 @@ func TestDataFrame_Corr(t *testing.T) {
res, err := df.Corr(ctx, "c1", "c2")
assert.NoError(t, err)
assert.Equal(t, -0.3592106040535498, res)
+
+ res2, err := df.Stat().Corr(ctx, "c1", "c2")
+ assert.NoError(t, err)
+ assert.Equal(t, res, res2)
}
func TestDataFrame_Cov(t *testing.T) {
@@ -205,6 +209,10 @@ func TestDataFrame_Cov(t *testing.T) {
res, err := df.Cov(ctx, "c1", "c2")
assert.NoError(t, err)
assert.Equal(t, -18.0, res)
+
+ res2, err := df.Stat().Cov(ctx, "c1", "c2")
+ assert.NoError(t, err)
+ assert.Equal(t, res, res2)
}
func TestDataFrame_WithColumn(t *testing.T) {
@@ -668,6 +676,16 @@ func TestDataFrame_CrossTab(t *testing.T) {
assert.Equal(t, int64(0), res[0].At(1))
assert.Equal(t, int64(2), res[0].At(2))
assert.Equal(t, int64(0), res[0].At(3))
+
+ df, err = spark.CreateDataFrame(ctx, data, schema)
+ assert.NoError(t, err)
+ df = df.Stat().CrossTab(ctx, "c1", "c2")
+ df, err = df.Sort(ctx, column.OfDF(df, "c1_c2").Asc())
+ assert.NoError(t, err)
+ res2, err := df.Collect(ctx)
+ assert.NoError(t, err)
+
+ assert.Equal(t, res, res2)
}
func TestDataFrame_SameSemantics(t *testing.T) {
@@ -696,6 +714,10 @@ func TestDataFrame_FreqItems(t *testing.T) {
res, err := df.FreqItems(ctx, "id").Collect(ctx)
assert.NoErrorf(t, err, "%+v", err)
assert.Len(t, res, 1)
+
+ res2, err := df.Stat().FreqItems(ctx, "id").Collect(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, res, res2)
}
func TestDataFrame_Config_GetAll(t *testing.T) {
@@ -926,6 +948,45 @@ func TestDataFrame_FillNa(t *testing.T) {
assert.Equal(t, []any{int64(1), int64(10), int64(1)}, res[1].Values())
}
+func TestDataFrame_ApproxQuantile(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select id, 1 as id2 from range(100)")
+ assert.NoError(t, err)
+ res, err := df.ApproxQuantile(ctx, []float64{float64(0.5)},
float64(0.1), "id")
+ assert.NoError(t, err)
+ assert.Len(t, res, 1)
+
+ data := [][]any{
+ {"bob", "Developer", 125000, 1},
+ {"mark", "Developer", 108000, 2},
+ {"carl", "Tester", 70000, 2},
+ {"peter", "Developer", 185000, 2},
+ {"jon", "Tester", 65000, 1},
+ {"roman", "Tester", 82000, 2},
+ {"simon", "Developer", 98000, 1},
+ {"eric", "Developer", 144000, 2},
+ {"carlos", "Tester", 75000, 1},
+ {"henry", "Developer", 110000, 1},
+ }
+ schema := types.StructOf(
+ types.NewStructField("Name", types.STRING),
+ types.NewStructField("Role", types.STRING),
+ types.NewStructField("Salary", types.LONG),
+ types.NewStructField("Performance", types.LONG),
+ )
+
+ df, err = spark.CreateDataFrame(ctx, data, schema)
+ assert.NoError(t, err)
+ med, err := df.ApproxQuantile(ctx, []float64{float64(0.5)},
float64(0.25), "Salary")
+
+ assert.NoError(t, err)
+ assert.Len(t, med, 1)
+ assert.GreaterOrEqual(t, med[0][0], 75000.0)
+
+ _, err = df.Stat().ApproxQuantile(ctx, []float64{0.5}, 0.25, "Salary")
+ assert.NoError(t, err)
+}
+
func TestDataFrame_DFNaFunctions(t *testing.T) {
ctx, spark := connect()
data := [][]any{
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 05a403e..2612387 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -45,6 +45,7 @@ type DataFrame interface {
AggWithMap(ctx context.Context, exprs map[string]string) (DataFrame,
error)
// Alias creates a new DataFrame with the specified subquery alias
Alias(ctx context.Context, alias string) DataFrame
+ ApproxQuantile(ctx context.Context, probabilities []float64,
relativeError float64, cols ...string) ([][]float64, error)
// Cache persists the DataFrame with the default storage level.
Cache(ctx context.Context) error
// Coalesce returns a new DataFrame that has exactly numPartitions
partitions.DataFrame
@@ -187,6 +188,7 @@ type DataFrame interface {
SemanticHash(ctx context.Context) (int32, error)
// Sort returns a new DataFrame sorted by the specified columns.
Sort(ctx context.Context, columns ...column.Convertible) (DataFrame,
error)
+ Stat() DataFrameStatFunctions
// Subtract subtracts the other DataFrame from the current DataFrame.
And only returns
// distinct rows.
Subtract(ctx context.Context, other DataFrame) DataFrame
@@ -1544,6 +1546,10 @@ func (df *dataFrameImpl) FillNaWithValues(ctx
context.Context,
return makeDataframeWithFillNaRelation(df, valueLiterals, columns), nil
}
+func (df *dataFrameImpl) Stat() DataFrameStatFunctions {
+ return &dataFrameStatFunctionsImpl{df: df}
+}
+
func (df *dataFrameImpl) Agg(ctx context.Context, cols ...column.Convertible)
(DataFrame, error) {
return df.GroupBy().Agg(ctx, cols...)
}
@@ -1560,6 +1566,50 @@ func (df *dataFrameImpl) AggWithMap(ctx context.Context,
exprs map[string]string
return df.Agg(ctx, funs...)
}
+func (df *dataFrameImpl) ApproxQuantile(ctx context.Context, probabilities
[]float64,
+ relativeError float64, cols ...string,
+) ([][]float64, error) {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_ApproxQuantile{
+ ApproxQuantile: &proto.StatApproxQuantile{
+ Input: df.relation,
+ Probabilities: probabilities,
+ RelativeError: relativeError,
+ Cols: cols,
+ },
+ },
+ }
+ data := NewDataFrame(df.session, rel)
+ rows, err := data.Collect(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ // The result structure is a bit weird here, essentially it returns
exactly one row with
+ // the quantiles.
+ // Inside the row is a list of nested arrays that contain the
quantiles. The first column is the
+ // first nested array, the second column is the second nested array and
so on.
+
+ nested := rows[0].At(0).([]interface{})
+ result := make([][]float64, len(nested))
+ for i := 0; i < len(nested); i++ {
+ tmp := nested[i].([]interface{})
+ result[i] = make([]float64, len(tmp))
+ for j := 0; j < len(tmp); j++ {
+ f, ok := tmp[j].(float64)
+ if !ok {
+ return nil, sparkerrors.WithType(fmt.Errorf(
+ "failed to cast to float64"),
sparkerrors.ExecutionError)
+ }
+ result[i][j] = f
+ }
+ }
+ return result, nil
+}
+
func (df *dataFrameImpl) DropNa(ctx context.Context, subset ...string)
(DataFrame, error) {
rel := &proto.Relation{
Common: &proto.RelationCommon{
diff --git a/spark/sql/dataframestatfunctions.go
b/spark/sql/dataframestatfunctions.go
new file mode 100644
index 0000000..7c05865
--- /dev/null
+++ b/spark/sql/dataframestatfunctions.go
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import "context"
+
+type DataFrameStatFunctions interface {
+ ApproxQuantile(ctx context.Context, probabilities []float64,
relativeError float64, cols ...string) ([][]float64, error)
+ Cov(ctx context.Context, col1, col2 string) (float64, error)
+ Corr(ctx context.Context, col1, col2 string) (float64, error)
+ CorrWithMethod(ctx context.Context, col1, col2 string, method string)
(float64, error)
+ CrossTab(ctx context.Context, col1, col2 string) DataFrame
+ FreqItems(ctx context.Context, cols ...string) DataFrame
+ FreqItemsWithSupport(ctx context.Context, support float64, cols
...string) DataFrame
+}
+
+type dataFrameStatFunctionsImpl struct {
+ df DataFrame
+}
+
+func (d *dataFrameStatFunctionsImpl) ApproxQuantile(ctx context.Context,
probabilities []float64,
+ relativeError float64, cols ...string,
+) ([][]float64, error) {
+ return d.df.ApproxQuantile(ctx, probabilities, relativeError, cols...)
+}
+
+func (d *dataFrameStatFunctionsImpl) Cov(ctx context.Context, col1, col2
string) (float64, error) {
+ return d.df.Cov(ctx, col1, col2)
+}
+
+func (d *dataFrameStatFunctionsImpl) Corr(ctx context.Context, col1, col2
string) (float64, error) {
+ return d.df.Corr(ctx, col1, col2)
+}
+
+func (d *dataFrameStatFunctionsImpl) CorrWithMethod(ctx context.Context, col1,
col2 string, method string) (float64, error) {
+ return d.df.CorrWithMethod(ctx, col1, col2, method)
+}
+
+func (d *dataFrameStatFunctionsImpl) CrossTab(ctx context.Context, col1, col2
string) DataFrame {
+ return d.df.CrossTab(ctx, col1, col2)
+}
+
+func (d *dataFrameStatFunctionsImpl) FreqItems(ctx context.Context, cols
...string) DataFrame {
+ return d.df.FreqItems(ctx, cols...)
+}
+
+func (d *dataFrameStatFunctionsImpl) FreqItemsWithSupport(ctx context.Context,
support float64, cols ...string) DataFrame {
+ return d.df.FreqItemsWithSupport(ctx, support, cols...)
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]