This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new 1e56e84 Added implementations of FillNa
1e56e84 is described below
commit 1e56e844b96c6e19e91ebbca88b0c95531062325
Author: Alex Ott <[email protected]>
AuthorDate: Tue Dec 31 13:09:36 2024 +0100
Added implementations of FillNa
### What changes were proposed in this pull request?
This PR adds `FillNa` and `FillNaWithValues` functions to `DataFrame`
### Why are the changes needed?
Added missing functions
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Closes #94 from alexott/fillna-implementation.
Authored-by: Alex Ott <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/dataframe_test.go | 53 ++++++++++++++++++++++++++++
spark/sql/dataframe.go | 46 ++++++++++++++++++++++++
2 files changed, 99 insertions(+)
diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index e00b375..232657b 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -31,6 +31,7 @@ import (
"github.com/apache/spark-connect-go/v35/spark/sql"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestDataFrame_Select(t *testing.T) {
@@ -802,3 +803,55 @@ func TestDataFrame_Unpivot(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, int64(4), cnt)
}
+
+func TestDataFrame_FillNa(t *testing.T) {
+ ctx, spark := connect()
+ file, err := os.CreateTemp("", "fillna")
+ defer os.Remove(file.Name())
+ assert.NoError(t, err)
+ defer file.Close()
+ _, err = file.WriteString(`{"id":1,"int":null, "int2": 1}
+{"id":null,"int":12, "int2": null}
+`)
+ assert.NoError(t, err)
+
+ df, err := spark.Read().Format("json").
+ Option("inferSchema", "true").
+ Load(file.Name())
+ assert.NoError(t, err)
+
+ // all columns
+ filled, err := df.FillNa(ctx, types.Int64(10))
+ assert.NoError(t, err)
+ sorted, err := filled.Sort(ctx, functions.Col("id").Asc())
+ assert.NoError(t, err)
+ res, err := sorted.Collect(ctx)
+ assert.NoError(t, err)
+ require.Equal(t, 2, len(res))
+ assert.Equal(t, []any{int64(1), int64(10), int64(1)}, res[0].Values())
+ assert.Equal(t, []any{int64(10), int64(12), int64(10)}, res[1].Values())
+
+ // specific columns
+ filled, err = df.FillNa(ctx, types.Int64(10), "int", "int2")
+ assert.NoError(t, err)
+ sorted, err = filled.Sort(ctx, functions.Col("id").Asc())
+ assert.NoError(t, err)
+ res, err = sorted.Collect(ctx)
+ assert.NoError(t, err)
+ require.Equal(t, 2, len(res))
+ assert.Equal(t, []any{nil, int64(12), int64(10)}, res[0].Values())
+ assert.Equal(t, []any{int64(1), int64(10), int64(1)}, res[1].Values())
+
+ // specific columns with map
+ filled, err = df.FillNaWithValues(ctx,
+ map[string]types.PrimitiveTypeLiteral{
+ "int": types.Int64(10), "int2": types.Int64(20),
+ })
+ assert.NoError(t, err)
+ sorted, err = filled.Sort(ctx, functions.Col("id").Asc())
+ assert.NoError(t, err)
+ res, err = sorted.Collect(ctx)
+ assert.NoError(t, err)
+ require.Equal(t, 2, len(res))
+ assert.Equal(t, []any{nil, int64(12), int64(20)}, res[0].Values())
+ assert.Equal(t, []any{int64(1), int64(10), int64(1)}, res[1].Values())
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 65f9484..5feb6a3 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -108,6 +108,10 @@ type DataFrame interface {
ExceptAll(ctx context.Context, other DataFrame) DataFrame
// Explain returns the string explain plan for the current DataFrame according to the explainMode.
Explain(ctx context.Context, explainMode utils.ExplainMode) (string, error)
+ // FillNa replaces null values with specified value.
+ FillNa(ctx context.Context, value types.PrimitiveTypeLiteral, columns ...string) (DataFrame, error)
+ // FillNaWithValues replaces null values in specified columns (key of the map) with values.
+ FillNaWithValues(ctx context.Context, values map[string]types.PrimitiveTypeLiteral) (DataFrame, error)
// Filter filters the data frame by a column condition.
Filter(ctx context.Context, condition column.Convertible) (DataFrame, error)
// FilterByString filters the data frame by a string condition.
@@ -1426,3 +1430,45 @@ func (df *dataFrameImpl) Unpivot(ctx context.Context,
}
return NewDataFrame(df.session, rel), nil
}
+
+func makeDataframeWithFillNaRelation(df *dataFrameImpl, values []*proto.Expression_Literal, columns []string) DataFrame {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_FillNa{
+ FillNa: &proto.NAFill{
+ Input: df.relation,
+ Cols: columns,
+ Values: values,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) FillNa(ctx context.Context, value types.PrimitiveTypeLiteral, columns ...string) (DataFrame, error) {
+ valueLiteral, err := value.ToProto(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return makeDataframeWithFillNaRelation(df, []*proto.Expression_Literal{
+ valueLiteral.GetLiteral(),
+ }, columns), nil
+}
+
+func (df *dataFrameImpl) FillNaWithValues(ctx context.Context,
+ values map[string]types.PrimitiveTypeLiteral,
+) (DataFrame, error) {
+ valueLiterals := make([]*proto.Expression_Literal, 0, len(values))
+ columns := make([]string, 0, len(values))
+ for k, v := range values {
+ valueLiteral, err := v.ToProto(ctx)
+ if err != nil {
+ return nil, err
+ }
+ valueLiterals = append(valueLiterals, valueLiteral.GetLiteral())
+ columns = append(columns, k)
+ }
+ return makeDataframeWithFillNaRelation(df, valueLiterals, columns), nil
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]