This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ae6f72  #84 adding `Replace()`
1ae6f72 is described below

commit 1ae6f720f6fa36c9c2a23f3edd0872123c86761a
Author: Martin Grund <[email protected]>
AuthorDate: Tue Dec 31 14:36:56 2024 +0100

    #84 adding `Replace()`
    
    ### What changes were proposed in this pull request?
    Adding support for the basic version of `Replace()`
    
    ### Why are the changes needed?
    API compatibility with PySpark's `DataFrame.replace()`.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds the new `DataFrame.Replace()` method.
    
    ### How was this patch tested?
    Added IT
    
    Closes #97 from grundprinzip/replace_84.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go | 70 ++++++++++++++++++++++++++++
 spark/sql/dataframe.go                       | 62 ++++++++++++++++++++++++
 2 files changed, 132 insertions(+)

diff --git a/internal/tests/integration/dataframe_test.go 
b/internal/tests/integration/dataframe_test.go
index 232657b..5a61d48 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -804,6 +804,76 @@ func TestDataFrame_Unpivot(t *testing.T) {
        assert.Equal(t, int64(4), cnt)
 }
 
+func TestDataFrame_Replace(t *testing.T) {
+       ctx, spark := connect()
+       data := [][]any{
+               {10, 80, "Alice"},
+               {5, nil, "Bob"},
+               {nil, 10, "Tom"},
+               {nil, nil, nil},
+       }
+       schema := types.StructOf(
+               types.NewStructField("age", types.INTEGER),
+               types.NewStructField("height", types.INTEGER),
+               types.NewStructField("name", types.STRING),
+       )
+       df, err := spark.CreateDataFrame(ctx, data, schema)
+       assert.NoError(t, err)
+
+       res, err := df.Replace(ctx,
+               []types.PrimitiveTypeLiteral{types.Int32(10)},
+               []types.PrimitiveTypeLiteral{types.Int32(20)},
+       )
+       assert.NoError(t, err)
+
+       cnt, err := res.Count(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, int64(4), cnt)
+
+       rows, err := res.Collect(ctx)
+       assert.NoError(t, err)
+
+       assert.Equal(t, int32(20), rows[0].At(0))
+       assert.Equal(t, int32(20), rows[2].At(1))
+
+       res, err = df.Replace(ctx,
+               []types.PrimitiveTypeLiteral{types.Int32(10)},
+               []types.PrimitiveTypeLiteral{types.Int32Nil},
+       )
+       assert.NoError(t, err)
+
+       rows, err = res.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Nil(t, rows[0].At(0))
+}
+
+func TestDataFrame_ReplaceWithColumn(t *testing.T) {
+       ctx, spark := connect()
+       data := [][]any{
+               {10, 80, "Alice"},
+               {5, nil, "Bob"},
+               {nil, 10, "Tom"},
+               {nil, nil, nil},
+       }
+       schema := types.StructOf(
+               types.NewStructField("age", types.INTEGER),
+               types.NewStructField("height", types.INTEGER),
+               types.NewStructField("name", types.STRING),
+       )
+       df, err := spark.CreateDataFrame(ctx, data, schema)
+       assert.NoError(t, err)
+
+       res, err := df.Replace(ctx, 
[]types.PrimitiveTypeLiteral{types.Int32(10)},
+               []types.PrimitiveTypeLiteral{types.Int32(20)}, "age")
+       assert.NoError(t, err)
+
+       rows, err := res.Collect(ctx)
+       assert.NoError(t, err)
+       // Should only repalce the age column but not the height column
+       assert.Equal(t, int32(20), rows[0].At(0))
+       assert.Equal(t, int32(10), rows[2].At(1))
+}
+
 func TestDataFrame_FillNa(t *testing.T) {
        ctx, spark := connect()
        file, err := os.CreateTemp("", "fillna")
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 5feb6a3..30b5293 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -150,6 +150,16 @@ type DataFrame interface {
        Repartition(ctx context.Context, numPartitions int, columns []string) 
(DataFrame, error)
        // RepartitionByRange re-partitions a data frame by range partition.
        RepartitionByRange(ctx context.Context, numPartitions int, columns 
...column.Convertible) (DataFrame, error)
+       // Replace returns a new DataFrame in which each value in toReplace is
+       // replaced by the value at the same index in values.
+       // toReplace and values must have the same length and type, and can only
+       // contain numerics, booleans, or strings. A replacement value may be a
+       // typed null literal (e.g. types.Int32Nil). When replacing, the new value
+       // is cast to the type of the existing column.
+       //
+       // For numeric replacements, all values to be replaced should have a unique
+       // floating-point representation. If cols is given, only that subset of
+       // columns is considered for the replacement.
+       Replace(ctx context.Context, toReplace []types.PrimitiveTypeLiteral,
+               values []types.PrimitiveTypeLiteral, cols ...string) 
(DataFrame, error)
        // Rollup creates a multi-dimensional rollup for the current DataFrame 
using
        // the specified columns, so we can run aggregation on them.
        Rollup(ctx context.Context, cols ...column.Convertible) *GroupedData
@@ -1378,6 +1388,58 @@ func (df *dataFrameImpl) Summary(ctx context.Context, 
statistics ...string) Data
        return NewDataFrame(df.session, rel)
 }
 
+func (df *dataFrameImpl) Replace(ctx context.Context,
+       toReplace []types.PrimitiveTypeLiteral, values 
[]types.PrimitiveTypeLiteral, cols ...string,
+) (DataFrame, error) {
+       if len(toReplace) != len(values) {
+               return nil, sparkerrors.WithType(fmt.Errorf(
+                       "toReplace and values must have the same length"), 
sparkerrors.InvalidArgumentError)
+       }
+
+       toReplaceExprs := make([]*proto.Expression, 0, len(toReplace))
+       for _, c := range toReplace {
+               expr, err := c.ToProto(ctx)
+               if err != nil {
+                       return nil, err
+               }
+               toReplaceExprs = append(toReplaceExprs, expr)
+       }
+
+       valuesExprs := make([]*proto.Expression, 0, len(values))
+       for _, c := range values {
+               expr, err := c.ToProto(ctx)
+               if err != nil {
+                       return nil, err
+               }
+               valuesExprs = append(valuesExprs, expr)
+       }
+
+       // Create a list of NAReplace expressions.
+       replacements := make([]*proto.NAReplace_Replacement, 0, len(toReplace))
+       for i := 0; i < len(toReplace); i++ {
+               replacement := &proto.NAReplace_Replacement{
+                       OldValue: toReplaceExprs[i].GetLiteral(),
+                       NewValue: valuesExprs[i].GetLiteral(),
+               }
+               replacements = append(replacements, replacement)
+       }
+
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+
+               RelType: &proto.Relation_Replace{
+                       Replace: &proto.NAReplace{
+                               Input:        df.relation,
+                               Replacements: replacements,
+                               Cols:         cols,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel), nil
+}
+
 func (df *dataFrameImpl) Melt(ctx context.Context,
        ids []column.Convertible,
        values []column.Convertible,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to