This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new 1ae6f72 #84 adding `Replace()`
1ae6f72 is described below
commit 1ae6f720f6fa36c9c2a23f3edd0872123c86761a
Author: Martin Grund <[email protected]>
AuthorDate: Tue Dec 31 14:36:56 2024 +0100
#84 adding `Replace()`
### What changes were proposed in this pull request?
Adding support for the basic version of `Replace()`
### Why are the changes needed?
Compatibility
### Does this PR introduce _any_ user-facing change?
Adding new functions
### How was this patch tested?
Added integration tests (IT).
Closes #97 from grundprinzip/replace_84.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/dataframe_test.go | 70 ++++++++++++++++++++++++++++
spark/sql/dataframe.go | 62 ++++++++++++++++++++++++
2 files changed, 132 insertions(+)
diff --git a/internal/tests/integration/dataframe_test.go
b/internal/tests/integration/dataframe_test.go
index 232657b..5a61d48 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -804,6 +804,76 @@ func TestDataFrame_Unpivot(t *testing.T) {
assert.Equal(t, int64(4), cnt)
}
+// TestDataFrame_Replace checks Replace when no column subset is given: the
+// value 10 is expected to become 20 in both integer columns, and replacing 10
+// with a null literal is expected to yield nil.
+//
+// Every step that later code dereferences or indexes is guarded, because
+// assert.* only records a failure and lets the test continue; without the
+// guards a failed call would surface as a nil-dereference or index panic.
+func TestDataFrame_Replace(t *testing.T) {
+	ctx, spark := connect()
+	data := [][]any{
+		{10, 80, "Alice"},
+		{5, nil, "Bob"},
+		{nil, 10, "Tom"},
+		{nil, nil, nil},
+	}
+	schema := types.StructOf(
+		types.NewStructField("age", types.INTEGER),
+		types.NewStructField("height", types.INTEGER),
+		types.NewStructField("name", types.STRING),
+	)
+	df, err := spark.CreateDataFrame(ctx, data, schema)
+	if !assert.NoError(t, err) {
+		return
+	}
+
+	// Replace 10 with 20 in all columns.
+	res, err := df.Replace(ctx,
+		[]types.PrimitiveTypeLiteral{types.Int32(10)},
+		[]types.PrimitiveTypeLiteral{types.Int32(20)},
+	)
+	if !assert.NoError(t, err) {
+		return
+	}
+
+	cnt, err := res.Count(ctx)
+	assert.NoError(t, err)
+	assert.Equal(t, int64(4), cnt)
+
+	rows, err := res.Collect(ctx)
+	if !assert.NoError(t, err) || !assert.Len(t, rows, 4) {
+		return
+	}
+
+	assert.Equal(t, int32(20), rows[0].At(0))
+	assert.Equal(t, int32(20), rows[2].At(1))
+
+	// Replace 10 with a null literal.
+	res, err = df.Replace(ctx,
+		[]types.PrimitiveTypeLiteral{types.Int32(10)},
+		[]types.PrimitiveTypeLiteral{types.Int32Nil},
+	)
+	if !assert.NoError(t, err) {
+		return
+	}
+
+	rows, err = res.Collect(ctx)
+	if !assert.NoError(t, err) || !assert.Len(t, rows, 4) {
+		return
+	}
+	assert.Nil(t, rows[0].At(0))
+}
+
+
+// TestDataFrame_ReplaceWithColumn checks Replace when it is restricted to the
+// "age" column: the 10 in "age" is expected to become 20 while the 10 in
+// "height" is expected to stay untouched.
+//
+// Results are guarded before they are dereferenced or indexed, because
+// assert.* only records a failure and lets the test continue.
+func TestDataFrame_ReplaceWithColumn(t *testing.T) {
+	ctx, spark := connect()
+	data := [][]any{
+		{10, 80, "Alice"},
+		{5, nil, "Bob"},
+		{nil, 10, "Tom"},
+		{nil, nil, nil},
+	}
+	schema := types.StructOf(
+		types.NewStructField("age", types.INTEGER),
+		types.NewStructField("height", types.INTEGER),
+		types.NewStructField("name", types.STRING),
+	)
+	df, err := spark.CreateDataFrame(ctx, data, schema)
+	if !assert.NoError(t, err) {
+		return
+	}
+
+	res, err := df.Replace(ctx,
+		[]types.PrimitiveTypeLiteral{types.Int32(10)},
+		[]types.PrimitiveTypeLiteral{types.Int32(20)}, "age")
+	if !assert.NoError(t, err) {
+		return
+	}
+
+	rows, err := res.Collect(ctx)
+	if !assert.NoError(t, err) || !assert.Len(t, rows, 4) {
+		return
+	}
+	// Should only replace the age column but not the height column.
+	assert.Equal(t, int32(20), rows[0].At(0))
+	assert.Equal(t, int32(10), rows[2].At(1))
+}
+
func TestDataFrame_FillNa(t *testing.T) {
ctx, spark := connect()
file, err := os.CreateTemp("", "fillna")
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 5feb6a3..30b5293 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -150,6 +150,16 @@ type DataFrame interface {
Repartition(ctx context.Context, numPartitions int, columns []string)
(DataFrame, error)
// RepartitionByRange re-partitions a data frame by range partition.
RepartitionByRange(ctx context.Context, numPartitions int, columns
...column.Convertible) (DataFrame, error)
+ // Replace returns a new DataFrame in which each value in toReplace is
+ // replaced by the value at the same position in values.
+ // toReplace and values must have the same type and can only be numerics,
+ // booleans, or strings. values may additionally contain null (nil)
+ // literals. When replacing, the new value will be cast to the type of the
+ // existing column.
+ //
+ // For numeric replacements, all values to be replaced should have a unique
+ // floating point representation. If cols is given, the replacement is
+ // performed only on that subset of columns.
+ Replace(ctx context.Context, toReplace []types.PrimitiveTypeLiteral,
+ values []types.PrimitiveTypeLiteral, cols ...string)
(DataFrame, error)
// Rollup creates a multi-dimensional rollup for the current DataFrame
using
// the specified columns, so we can run aggregation on them.
Rollup(ctx context.Context, cols ...column.Convertible) *GroupedData
@@ -1378,6 +1388,58 @@ func (df *dataFrameImpl) Summary(ctx context.Context,
statistics ...string) Data
return NewDataFrame(df.session, rel)
}
+// Replace builds a new DataFrame whose plan substitutes toReplace[i] with
+// values[i]. When cols is non-empty it is forwarded as the set of columns the
+// replacement is limited to.
+//
+// An error typed as sparkerrors.InvalidArgumentError is returned when the two
+// slices differ in length. Any error from converting a literal to its proto
+// form is returned unchanged.
+func (df *dataFrameImpl) Replace(ctx context.Context,
+	toReplace []types.PrimitiveTypeLiteral, values []types.PrimitiveTypeLiteral, cols ...string,
+) (DataFrame, error) {
+	if len(toReplace) != len(values) {
+		return nil, sparkerrors.WithType(fmt.Errorf(
+			"toReplace and values must have the same length"), sparkerrors.InvalidArgumentError)
+	}
+
+	// asExprs converts a list of literals into proto expressions, stopping at
+	// the first conversion error.
+	asExprs := func(literals []types.PrimitiveTypeLiteral) ([]*proto.Expression, error) {
+		exprs := make([]*proto.Expression, len(literals))
+		for i, lit := range literals {
+			expr, err := lit.ToProto(ctx)
+			if err != nil {
+				return nil, err
+			}
+			exprs[i] = expr
+		}
+		return exprs, nil
+	}
+
+	// All old values are converted before any new value, so the first error
+	// reported is always the earliest one in that order.
+	oldExprs, err := asExprs(toReplace)
+	if err != nil {
+		return nil, err
+	}
+	newExprs, err := asExprs(values)
+	if err != nil {
+		return nil, err
+	}
+
+	// Pair the literals up position by position.
+	replacements := make([]*proto.NAReplace_Replacement, len(oldExprs))
+	for i, oldExpr := range oldExprs {
+		replacements[i] = &proto.NAReplace_Replacement{
+			OldValue: oldExpr.GetLiteral(),
+			NewValue: newExprs[i].GetLiteral(),
+		}
+	}
+
+	naReplace := &proto.NAReplace{
+		Input:        df.relation,
+		Replacements: replacements,
+		Cols:         cols,
+	}
+	rel := &proto.Relation{
+		Common:  &proto.RelationCommon{PlanId: newPlanId()},
+		RelType: &proto.Relation_Replace{Replace: naReplace},
+	}
+	return NewDataFrame(df.session, rel), nil
+}
+
+
func (df *dataFrameImpl) Melt(ctx context.Context,
ids []column.Convertible,
values []column.Convertible,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]