This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new e93f73b  [#58] Adding Support for select, selectExpr, crossJoin, and Alias
e93f73b is described below

commit e93f73b471f6949f975db02b986f917c1b19f064
Author: Martin Grund <[email protected]>
AuthorDate: Mon Aug 26 10:07:11 2024 +0200

    [#58] Adding Support for select, selectExpr, crossJoin, and Alias
    
    ### What changes were proposed in this pull request?
    Adds support in the DataFrame for:
    
    * `Alias`
    *  `Select`
    *  `SelectExpr`
    *  `CrossJoin`
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    ### How was this patch tested?
    Added new integration tests.
    
    Closes #59 from grundprinzip/df_functions_v1.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go | 93 ++++++++++++++++++++++++++++
 spark/sql/dataframe.go                       | 90 +++++++++++++++++++++++++++
 2 files changed, 183 insertions(+)

diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
new file mode 100644
index 0000000..b6fa117
--- /dev/null
+++ b/internal/tests/integration/dataframe_test.go
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration
+
+import (
+       "context"
+       "testing"
+
+       "github.com/apache/spark-connect-go/v35/spark/sql"
+       "github.com/apache/spark-connect-go/v35/spark/sql/functions"
+       "github.com/stretchr/testify/assert"
+)
+
+func TestDataFrame_Select(t *testing.T) {
+       ctx := context.Background()
+       spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+       assert.NoError(t, err)
+       df, err := spark.Sql(ctx, "select * from range(100)")
+       assert.NoError(t, err)
+       df, err = df.Select(functions.Lit("1"), functions.Lit("2"))
+       assert.NoError(t, err)
+
+       res, err := df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 100, len(res))
+
+       row_zero := res[0]
+       vals, err := row_zero.Values()
+       assert.NoError(t, err)
+       assert.Equal(t, 2, len(vals))
+}
+
+func TestDataFrame_SelectExpr(t *testing.T) {
+       ctx := context.Background()
+       spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+       assert.NoError(t, err)
+       df, err := spark.Sql(ctx, "select * from range(100)")
+       assert.NoError(t, err)
+       df, err = df.SelectExpr("1", "2", "spark_partition_id()")
+       assert.NoError(t, err)
+
+       res, err := df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 100, len(res))
+
+       row_zero := res[0]
+       vals, err := row_zero.Values()
+       assert.NoError(t, err)
+       assert.Equal(t, 3, len(vals))
+}
+
+func TestDataFrame_Alias(t *testing.T) {
+       ctx := context.Background()
+       spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+       assert.NoError(t, err)
+       df, err := spark.Sql(ctx, "select * from range(100)")
+       assert.NoError(t, err)
+       df = df.Alias("df")
+       res, er := df.Collect(ctx)
+       assert.NoError(t, er)
+       assert.Equal(t, 100, len(res))
+}
+
+func TestDataFrame_CrossJoin(t *testing.T) {
+       ctx := context.Background()
+       spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+       assert.NoError(t, err)
+       df1, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       df2, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       df := df1.CrossJoin(df2)
+       res, err := df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 100, len(res))
+
+       v, e := res[0].Values()
+       assert.NoError(t, e)
+       assert.Equal(t, 2, len(v))
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index cd8ba64..eb285db 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -62,6 +62,15 @@ type DataFrame interface {
        FilterByString(condition string) (DataFrame, error)
        // Col returns a column by name.
        Col(name string) (column.Column, error)
+
+       // Select projects a list of columns from the DataFrame
+       Select(columns ...column.Column) (DataFrame, error)
+       // SelectExpr projects a list of columns from the DataFrame by string expressions
+       SelectExpr(exprs ...string) (DataFrame, error)
+       // Alias creates a new DataFrame with the specified subquery alias
+       Alias(alias string) DataFrame
+       // CrossJoin joins the current DataFrame with another DataFrame using the cross product
+       CrossJoin(other DataFrame) DataFrame
 }
 
 type RangePartitionColumn struct {
@@ -75,6 +84,63 @@ type dataFrameImpl struct {
        relation *proto.Relation // TODO change to proto.Plan?
 }
 
+func (df *dataFrameImpl) SelectExpr(exprs ...string) (DataFrame, error) {
+       expressions := make([]*proto.Expression, 0, len(exprs))
+       for _, expr := range exprs {
+               col := functions.Expr(expr)
+               f, e := col.ToPlan()
+               if e != nil {
+                       return nil, e
+               }
+               expressions = append(expressions, f)
+       }
+
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Project{
+                       Project: &proto.Project{
+                               Input:       df.relation,
+                               Expressions: expressions,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel), nil
+}
+
+func (df *dataFrameImpl) Alias(alias string) DataFrame {
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_SubqueryAlias{
+                       SubqueryAlias: &proto.SubqueryAlias{
+                               Input: df.relation,
+                               Alias: alias,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) CrossJoin(other DataFrame) DataFrame {
+       otherDf := other.(*dataFrameImpl)
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Join{
+                       Join: &proto.Join{
+                               Left:     df.relation,
+                               Right:    otherDf.relation,
+                               JoinType: proto.Join_JOIN_TYPE_CROSS,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
 // NewDataFrame creates a new DataFrame
 func NewDataFrame(session *sparkSessionImpl, relation *proto.Relation) DataFrame {
        return &dataFrameImpl{
@@ -329,3 +395,27 @@ func (df *dataFrameImpl) Col(name string) (column.Column, error) {
        planId := df.relation.Common.GetPlanId()
        return column.NewColumn(column.NewColumnReferenceWithPlanId(name, planId)), nil
 }
+
+func (df *dataFrameImpl) Select(columns ...column.Column) (DataFrame, error) {
+       exprs := make([]*proto.Expression, 0, len(columns))
+       for _, c := range columns {
+               expr, err := c.ToPlan()
+               if err != nil {
+                       return nil, err
+               }
+               exprs = append(exprs, expr)
+       }
+
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Project{
+                       Project: &proto.Project{
+                               Input:       df.relation,
+                               Expressions: exprs,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel), nil
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to