This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch grundprinzip/73_with_columns
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git

commit d37b0813b21729f268695b45c06153e39c63f21e
Author: Martin Grund <[email protected]>
AuthorDate: Thu Oct 3 04:59:27 2024 -0700

    withColumns
---
 internal/tests/integration/dataframe_test.go |  6 ++----
 spark/sql/column/column.go                   | 27 +++++++++++++++++++++++
 spark/sql/column/column_test.go              | 32 ++++++++++++++++++++++++++++
 spark/sql/dataframe.go                       | 16 ++++++--------
 4 files changed, 68 insertions(+), 13 deletions(-)

diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index bb9e2aa..ae9e06b 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -224,10 +224,8 @@ func TestDataFrame_WithColumns(t *testing.T) {
        ctx, spark := connect()
        df, err := spark.Sql(ctx, "select * from range(10)")
        assert.NoError(t, err)
-       df, err = df.WithColumns(ctx, map[string]column.Convertible{
-               "newCol1": functions.Lit(1),
-               "newCol2": functions.Lit(2),
-       })
+       df, err = df.WithColumns(ctx, column.WithAlias("newCol1", functions.Lit(1)),
+               column.WithAlias("newCol2", functions.Lit(2)))
        assert.NoError(t, err)
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go
index b2de941..ce70189 100644
--- a/spark/sql/column/column.go
+++ b/spark/sql/column/column.go
@@ -108,3 +108,30 @@ func OfDFWithRegex(df SchemaDataFrame, colRegex string) Column {
        planId := df.PlanId()
        return NewColumn(&unresolvedRegex{colRegex, &planId})
 }
+
+type Alias struct {
+       Name string
+       Col  Convertible
+}
+
+func (a Alias) ToProto(ctx context.Context) (*proto.Expression, error) {
+       col, err := a.Col.ToProto(ctx)
+       if err != nil {
+               return nil, err
+       }
+       return &proto.Expression{
+               ExprType: &proto.Expression_Alias_{
+                       Alias: &proto.Expression_Alias{
+                               Expr: col,
+                               Name: []string{a.Name},
+                       },
+               },
+       }, nil
+}
+
+func WithAlias(name string, col Convertible) Alias {
+       return Alias{
+               Name: name,
+               Col:  col,
+       }
+}
diff --git a/spark/sql/column/column_test.go b/spark/sql/column/column_test.go
index ba66a93..040a0b2 100644
--- a/spark/sql/column/column_test.go
+++ b/spark/sql/column/column_test.go
@@ -338,3 +338,35 @@ func TestColumnFunctions(t *testing.T) {
                })
        }
 }
+
+func TestColumn_Alias(t *testing.T) {
+       col1 := NewColumn(NewColumnReference("col1"))
+       col1Plan, _ := col1.ToProto(context.Background())
+
+       tests := []struct {
+               name string
+               arg  Convertible
+               want *proto.Expression
+       }{
+               {
+                       name: "TestColumnAlias",
+                       arg:  WithAlias("alias", col1),
+                       want: &proto.Expression{
+                               ExprType: &proto.Expression_Alias_{
+                                       Alias: &proto.Expression_Alias{
+                                               Expr: col1Plan,
+                                               Name: []string{"alias"},
+                                       },
+                               },
+                       },
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       got, err := tt.arg.ToProto(context.Background())
+                       assert.NoError(t, err)
+                       expected := tt.want
+                       assert.Equalf(t, expected, got, "Input: %v", tt.arg)
+               })
+       }
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 85792b7..55e3ed6 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -198,7 +198,7 @@ type DataFrame interface {
        // plans which can cause performance issues and even `StackOverflowException`.
        // To avoid this, use :func:`select` with multiple columns at once.
        WithColumn(ctx context.Context, colName string, col column.Convertible) (DataFrame, error)
-       WithColumns(ctx context.Context, colsMap map[string]column.Convertible) (DataFrame, error)
+       WithColumns(ctx context.Context, alias ...column.Alias) (DataFrame, error)
        // WithColumnRenamed returns a new DataFrame by renaming an existing column.
        // This is a no-op if the schema doesn't contain the given column name.
        WithColumnRenamed(ctx context.Context, existingName, newName string) (DataFrame, error)
@@ -665,21 +665,19 @@ func (df *dataFrameImpl) GroupBy(cols ...column.Convertible) *GroupedData {
 }
 
 func (df *dataFrameImpl) WithColumn(ctx context.Context, colName string, col column.Convertible) (DataFrame, error) {
-       return df.WithColumns(ctx, map[string]column.Convertible{colName: col})
+       return df.WithColumns(ctx, column.WithAlias(colName, col))
 }
 
-func (df *dataFrameImpl) WithColumns(ctx context.Context, colsMap map[string]column.Convertible) (DataFrame, error) {
+func (df *dataFrameImpl) WithColumns(ctx context.Context, cols ...column.Alias) (DataFrame, error) {
        // Convert all columns to proto expressions and the corresponding alias:
-       aliases := make([]*proto.Expression_Alias, 0, len(colsMap))
-       for colName, col := range colsMap {
+       aliases := make([]*proto.Expression_Alias, 0, len(cols))
+       for _, col := range cols {
                expr, err := col.ToProto(ctx)
                if err != nil {
                        return nil, err
                }
-               alias := &proto.Expression_Alias{
-                       Expr: expr,
-                       Name: []string{colName},
-               }
+               // The alias must be an alias expression.
+               alias := expr.GetAlias()
                aliases = append(aliases, alias)
        }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to