This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 594f0b7  Improve Package Structure
594f0b7 is described below

commit 594f0b74e80638a20a059f4f91032f7bfe09a621
Author: Martin Grund <[email protected]>
AuthorDate: Wed Aug 28 16:50:46 2024 +0200

    Improve Package Structure
    
    ### What changes were proposed in this pull request?
    This patch fixes some issues in the current code and in the way the
interfaces are used by and exposed to other packages. In particular, all
DataFrame functions now take a context, both for API consistency and so that
the schema and similar information can be resolved dynamically.
    
    Second, all columns and expressions must implement the
`column.Convertible` interface, which indicates that they can be converted to
protobuf. Third, the column package adds a `HasSchema`-like interface
(`SchemaDataFrame`) that allows passing a DataFrame without depending on the
concrete `DataFrame` interface type.
    
    Lastly, it renames the `ToPlan` method on `Column` / `Expression` /
`Convertible` to `ToProto` for clarity.
    
    ### Why are the changes needed?
    Ease of use.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, slight package and API changes (e.g. DataFrame methods now take a
`context.Context`), but the project is still pre-release.
    
    ### How was this patch tested?
    Existing tests.
    
    Closes #65 from grundprinzip/package_refactorign.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 cmd/spark-connect-example-spark-session/main.go |  19 ++--
 internal/tests/integration/dataframe_test.go    |  18 +++-
 internal/tests/integration/functions_test.go    |   5 +-
 internal/tests/integration/spark_runner.go      |   5 +-
 internal/tests/integration/sql_test.go          |   9 +-
 spark/sql/column/column.go                      |  66 ++++++++++---
 spark/sql/column/column_test.go                 |  18 ++--
 spark/sql/column/expressions.go                 | 126 ++++++++++++++++++------
 spark/sql/column/expressions_test.go            |  19 ++--
 spark/sql/dataframe.go                          |  85 ++++++----------
 spark/sql/functions/buiitins.go                 |   4 +-
 spark/sql/functions/generated.go                |   4 +-
 spark/sql/sparksession_test.go                  |   2 +-
 13 files changed, 238 insertions(+), 142 deletions(-)

diff --git a/cmd/spark-connect-example-spark-session/main.go 
b/cmd/spark-connect-example-spark-session/main.go
index 163009f..6a1fe6a 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -50,7 +50,7 @@ func main() {
                log.Fatalf("Failed: %s", err)
        }
 
-       df, _ = df.FilterByString("id < 10")
+       df, _ = df.FilterByString(ctx, "id < 10")
        err = df.Show(ctx, 100, false)
        if err != nil {
                log.Fatalf("Failed: %s", err)
@@ -61,14 +61,14 @@ func main() {
                log.Fatalf("Failed: %s", err)
        }
 
-       df, _ = df.Filter(functions.Col("id").Lt(functions.Expr("10")))
+       df, _ = df.Filter(ctx, functions.Col("id").Lt(functions.Expr("10")))
        err = df.Show(ctx, 100, false)
        if err != nil {
                log.Fatalf("Failed: %s", err)
        }
 
        df, _ = spark.Sql(ctx, "select * from range(100)")
-       df, err = df.Filter(functions.Col("id").Lt(functions.Lit(20)))
+       df, err = df.Filter(ctx, functions.Col("id").Lt(functions.Lit(20)))
        if err != nil {
                log.Fatalf("Failed: %s", err)
        }
@@ -151,7 +151,7 @@ func main() {
        }
 
        log.Printf("Repartition with one partition")
-       df, err = df.Repartition(1, nil)
+       df, err = df.Repartition(ctx, 1, nil)
        if err != nil {
                log.Fatalf("Failed: %s", err)
        }
@@ -164,7 +164,7 @@ func main() {
        }
 
        log.Printf("Repartition with two partitions")
-       df, err = df.Repartition(2, nil)
+       df, err = df.Repartition(ctx, 2, nil)
        if err != nil {
                log.Fatalf("Failed: %s", err)
        }
@@ -177,7 +177,7 @@ func main() {
        }
 
        log.Printf("Repartition with columns")
-       df, err = df.Repartition(0, []string{"word", "count"})
+       df, err = df.Repartition(ctx, 0, []string{"word", "count"})
        if err != nil {
                log.Fatalf("Failed: %s", err)
        }
@@ -190,12 +190,7 @@ func main() {
        }
 
        log.Printf("Repartition by range with columns")
-       df, err = df.RepartitionByRange(0, []sql.RangePartitionColumn{
-               {
-                       Name:       "word",
-                       Descending: true,
-               },
-       })
+       df, err = df.RepartitionByRange(ctx, 0, functions.Col("word").Desc())
        if err != nil {
                log.Fatalf("Failed: %s", err)
        }
diff --git a/internal/tests/integration/dataframe_test.go 
b/internal/tests/integration/dataframe_test.go
index b6fa117..9fe2e3d 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -19,8 +19,11 @@ import (
        "context"
        "testing"
 
-       "github.com/apache/spark-connect-go/v35/spark/sql"
+       "github.com/apache/spark-connect-go/v35/spark/sql/column"
+
        "github.com/apache/spark-connect-go/v35/spark/sql/functions"
+
+       "github.com/apache/spark-connect-go/v35/spark/sql"
        "github.com/stretchr/testify/assert"
 )
 
@@ -30,7 +33,7 @@ func TestDataFrame_Select(t *testing.T) {
        assert.NoError(t, err)
        df, err := spark.Sql(ctx, "select * from range(100)")
        assert.NoError(t, err)
-       df, err = df.Select(functions.Lit("1"), functions.Lit("2"))
+       df, err = df.Select(ctx, functions.Lit("1"), functions.Lit("2"))
        assert.NoError(t, err)
 
        res, err := df.Collect(ctx)
@@ -41,6 +44,11 @@ func TestDataFrame_Select(t *testing.T) {
        vals, err := row_zero.Values()
        assert.NoError(t, err)
        assert.Equal(t, 2, len(vals))
+
+       df, err = spark.Sql(ctx, "select * from range(100)")
+       assert.NoError(t, err)
+       _, err = df.Select(ctx, column.OfDF(df, "id2"))
+       assert.Error(t, err)
 }
 
 func TestDataFrame_SelectExpr(t *testing.T) {
@@ -49,7 +57,7 @@ func TestDataFrame_SelectExpr(t *testing.T) {
        assert.NoError(t, err)
        df, err := spark.Sql(ctx, "select * from range(100)")
        assert.NoError(t, err)
-       df, err = df.SelectExpr("1", "2", "spark_partition_id()")
+       df, err = df.SelectExpr(ctx, "1", "2", "spark_partition_id()")
        assert.NoError(t, err)
 
        res, err := df.Collect(ctx)
@@ -68,7 +76,7 @@ func TestDataFrame_Alias(t *testing.T) {
        assert.NoError(t, err)
        df, err := spark.Sql(ctx, "select * from range(100)")
        assert.NoError(t, err)
-       df = df.Alias("df")
+       df = df.Alias(ctx, "df")
        res, er := df.Collect(ctx)
        assert.NoError(t, er)
        assert.Equal(t, 100, len(res))
@@ -82,7 +90,7 @@ func TestDataFrame_CrossJoin(t *testing.T) {
        assert.NoError(t, err)
        df2, err := spark.Sql(ctx, "select * from range(10)")
        assert.NoError(t, err)
-       df := df1.CrossJoin(df2)
+       df := df1.CrossJoin(ctx, df2)
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
        assert.Equal(t, 100, len(res))
diff --git a/internal/tests/integration/functions_test.go 
b/internal/tests/integration/functions_test.go
index 659cb21..1f89923 100644
--- a/internal/tests/integration/functions_test.go
+++ b/internal/tests/integration/functions_test.go
@@ -19,8 +19,9 @@ import (
        "context"
        "testing"
 
-       "github.com/apache/spark-connect-go/v35/spark/sql"
        "github.com/apache/spark-connect-go/v35/spark/sql/functions"
+
+       "github.com/apache/spark-connect-go/v35/spark/sql"
        "github.com/stretchr/testify/assert"
 )
 
@@ -32,7 +33,7 @@ func TestIntegration_BuiltinFunctions(t *testing.T) {
        }
 
        df, _ := spark.Sql(ctx, "select '[2]' as a from range(10)")
-       df, _ = 
df.Filter(functions.JsonArrayLength(functions.Col("a")).Eq(functions.Lit(1)))
+       df, _ = df.Filter(ctx, 
functions.JsonArrayLength(functions.Col("a")).Eq(functions.Lit(1)))
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
        assert.Equal(t, 10, len(res))
diff --git a/internal/tests/integration/spark_runner.go 
b/internal/tests/integration/spark_runner.go
index 060693f..c4568d2 100644
--- a/internal/tests/integration/spark_runner.go
+++ b/internal/tests/integration/spark_runner.go
@@ -34,10 +34,13 @@ func StartSparkConnect() (int64, error) {
 
        fmt.Printf("Starting Spark Connect Server in: %v\n", 
os.Getenv("SPARK_HOME"))
 
-       cmd := exec.Command("./sbin/start-connect-server.sh", "--wait", 
"--conf",
+       cmd := exec.Command("./sbin/start-connect-server.sh", "--conf",
                "spark.log.structuredLogging.enabled=false", "--packages",
                "org.apache.spark:spark-connect_2.12:3.5.2")
        cmd.Dir = sparkHome
+       baseEnv := os.Environ()
+       baseEnv = append(baseEnv, "SPARK_NO_DAEMONIZE=1")
+       cmd.Env = baseEnv
 
        stdout, _ := cmd.StdoutPipe()
        if err := cmd.Start(); err != nil {
diff --git a/internal/tests/integration/sql_test.go 
b/internal/tests/integration/sql_test.go
index eddef41..c0235de 100644
--- a/internal/tests/integration/sql_test.go
+++ b/internal/tests/integration/sql_test.go
@@ -21,10 +21,13 @@ import (
        "os"
        "testing"
 
+       "github.com/apache/spark-connect-go/v35/spark/sql/column"
+
+       "github.com/apache/spark-connect-go/v35/spark/sql/functions"
+
        "github.com/apache/spark-connect-go/v35/spark/sql/types"
 
        "github.com/apache/spark-connect-go/v35/spark/sql"
-       "github.com/apache/spark-connect-go/v35/spark/sql/functions"
        "github.com/stretchr/testify/assert"
 )
 
@@ -42,9 +45,7 @@ func TestIntegration_RunSQLCommand(t *testing.T) {
        assert.NoError(t, err)
        assert.Equal(t, 100, len(res))
 
-       col, err := df.Col("id")
-       assert.NoError(t, err)
-       df, err = df.Filter(col.Lt(functions.Lit(10)))
+       df, err = df.Filter(ctx, column.OfDF(df, "id").Lt(functions.Lit(10)))
        assert.NoError(t, err)
        res, err = df.Collect(ctx)
        assert.NoErrorf(t, err, "Must be able to collect the rows.")
diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go
index 8b37b1a..aae8fd1 100644
--- a/spark/sql/column/column.go
+++ b/spark/sql/column/column.go
@@ -15,51 +15,87 @@
 
 package column
 
-import proto "github.com/apache/spark-connect-go/v35/internal/generated"
+import (
+       "context"
+
+       "github.com/apache/spark-connect-go/v35/spark/sql/types"
+
+       proto "github.com/apache/spark-connect-go/v35/internal/generated"
+)
+
+// Convertible is the interface for all things that can be converted into a 
protobuf expression.
+type Convertible interface {
+       ToProto(ctx context.Context) (*proto.Expression, error)
+}
 
 type Column struct {
-       Expr Expression
+       expr expression
 }
 
-func (c *Column) ToPlan() (*proto.Expression, error) {
-       return c.Expr.ToPlan()
+func (c Column) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return c.expr.ToProto(ctx)
 }
 
 func (c Column) Lt(other Column) Column {
-       return NewColumn(NewUnresolvedFunction("<", []Expression{c.Expr, 
other.Expr}, false))
+       return NewColumn(NewUnresolvedFunction("<", []expression{c.expr, 
other.expr}, false))
 }
 
 func (c Column) Le(other Column) Column {
-       return NewColumn(NewUnresolvedFunction("<=", []Expression{c.Expr, 
other.Expr}, false))
+       return NewColumn(NewUnresolvedFunction("<=", []expression{c.expr, 
other.expr}, false))
 }
 
 func (c Column) Gt(other Column) Column {
-       return NewColumn(NewUnresolvedFunction(">", []Expression{c.Expr, 
other.Expr}, false))
+       return NewColumn(NewUnresolvedFunction(">", []expression{c.expr, 
other.expr}, false))
 }
 
 func (c Column) Ge(other Column) Column {
-       return NewColumn(NewUnresolvedFunction(">=", []Expression{c.Expr, 
other.Expr}, false))
+       return NewColumn(NewUnresolvedFunction(">=", []expression{c.expr, 
other.expr}, false))
 }
 
 func (c Column) Eq(other Column) Column {
-       return NewColumn(NewUnresolvedFunction("==", []Expression{c.Expr, 
other.Expr}, false))
+       return NewColumn(NewUnresolvedFunction("==", []expression{c.expr, 
other.expr}, false))
 }
 
 func (c Column) Neq(other Column) Column {
-       cmp := NewUnresolvedFunction("==", []Expression{c.Expr, other.Expr}, 
false)
-       return NewColumn(NewUnresolvedFunction("not", []Expression{cmp}, false))
+       cmp := NewUnresolvedFunction("==", []expression{c.expr, other.expr}, 
false)
+       return NewColumn(NewUnresolvedFunction("not", []expression{cmp}, false))
 }
 
 func (c Column) Mul(other Column) Column {
-       return NewColumn(NewUnresolvedFunction("*", []Expression{c.Expr, 
other.Expr}, false))
+       return NewColumn(NewUnresolvedFunction("*", []expression{c.expr, 
other.expr}, false))
 }
 
 func (c Column) Div(other Column) Column {
-       return NewColumn(NewUnresolvedFunction("/", []Expression{c.Expr, 
other.Expr}, false))
+       return NewColumn(NewUnresolvedFunction("/", []expression{c.expr, 
other.expr}, false))
+}
+
+func (c Column) Desc() Column {
+       return NewColumn(&sortExpression{
+               child:        c.expr,
+               direction:    
proto.Expression_SortOrder_SORT_DIRECTION_DESCENDING,
+               nullOrdering: proto.Expression_SortOrder_SORT_NULLS_LAST,
+       })
 }
 
-func NewColumn(expr Expression) Column {
+func (c Column) Asc() Column {
+       return NewColumn(&sortExpression{
+               child:        c.expr,
+               direction:    
proto.Expression_SortOrder_SORT_DIRECTION_ASCENDING,
+               nullOrdering: proto.Expression_SortOrder_SORT_NULLS_FIRST,
+       })
+}
+
+func NewColumn(expr expression) Column {
        return Column{
-               Expr: expr,
+               expr: expr,
        }
 }
+
+type SchemaDataFrame interface {
+       PlanId() int64
+       Schema(ctx context.Context) (*types.StructType, error)
+}
+
+func OfDF(df SchemaDataFrame, colName string) Column {
+       return NewColumn(&delayedColumnReference{colName, df})
+}
diff --git a/spark/sql/column/column_test.go b/spark/sql/column/column_test.go
index 1917461..b823921 100644
--- a/spark/sql/column/column_test.go
+++ b/spark/sql/column/column_test.go
@@ -16,6 +16,7 @@
 package column
 
 import (
+       "context"
        "testing"
 
        proto "github.com/apache/spark-connect-go/v35/internal/generated"
@@ -23,14 +24,15 @@ import (
 )
 
 func TestNewUnresolvedFunction_Basic(t *testing.T) {
+       ctx := context.Background()
        col1 := NewColumn(NewColumnReference("col1"))
        col2 := NewColumn(NewColumnReference("col2"))
-       col1Plan, _ := col1.ToPlan()
-       col2Plan, _ := col2.ToPlan()
+       col1Plan, _ := col1.ToProto(ctx)
+       col2Plan, _ := col2.ToProto(ctx)
 
        type args struct {
                name       string
-               arguments  []Expression
+               arguments  []expression
                isDistinct bool
        }
        tests := []struct {
@@ -42,7 +44,7 @@ func TestNewUnresolvedFunction_Basic(t *testing.T) {
                        name: "TestNewUnresolvedWithArguments",
                        args: args{
                                name:       "id",
-                               arguments:  []Expression{col1.Expr, col2.Expr},
+                               arguments:  []expression{col1.expr, col2.expr},
                                isDistinct: false,
                        },
                        want: &proto.Expression{
@@ -62,7 +64,7 @@ func TestNewUnresolvedFunction_Basic(t *testing.T) {
                        name: "TestNewUnresolvedWithArgumentsEmpty",
                        args: args{
                                name:       "id",
-                               arguments:  []Expression{},
+                               arguments:  []expression{},
                                isDistinct: true,
                        },
                        want: &proto.Expression{
@@ -79,7 +81,7 @@ func TestNewUnresolvedFunction_Basic(t *testing.T) {
                t.Run(tt.name, func(t *testing.T) {
                        got := NewUnresolvedFunction(tt.args.name, 
tt.args.arguments, tt.args.isDistinct)
                        expected := tt.want
-                       p, err := got.ToPlan()
+                       p, err := got.ToProto(ctx)
                        assert.NoError(t, err)
                        assert.Equalf(t, expected, p, "Input: %v", tt.args)
                })
@@ -315,10 +317,10 @@ func TestColumnFunctions(t *testing.T) {
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       got, err := tt.arg.ToPlan()
+                       got, err := tt.arg.ToProto(context.Background())
                        assert.NoError(t, err)
                        expected := tt.want
-                       assert.Equalf(t, expected, got, "Input: %v", 
tt.arg.Expr.DebugString())
+                       assert.Equalf(t, expected, got, "Input: %v", 
tt.arg.expr.DebugString())
                })
        }
 }
diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go
index 03bac2e..13f3232 100644
--- a/spark/sql/column/expressions.go
+++ b/spark/sql/column/expressions.go
@@ -16,6 +16,7 @@
 package column
 
 import (
+       "context"
        "fmt"
        "strings"
 
@@ -28,23 +29,88 @@ func newProtoExpression() *proto.Expression {
        return &proto.Expression{}
 }
 
-// Expression is the interface for all expressions used by Spark Connect.
-type Expression interface {
-       ToPlan() (*proto.Expression, error)
+// expression is the interface for all expressions used by Spark Connect.
+type expression interface {
+       ToProto(context.Context) (*proto.Expression, error)
        DebugString() string
 }
 
+type delayedColumnReference struct {
+       unparsedIdentifier string
+       df                 SchemaDataFrame
+}
+
+func (d *delayedColumnReference) DebugString() string {
+       return d.unparsedIdentifier
+}
+
+func (d *delayedColumnReference) ToProto(ctx context.Context) 
(*proto.Expression, error) {
+       // Check if the column identifier is actually part of the schema.
+       schema, err := d.df.Schema(ctx)
+       if err != nil {
+               return nil, err
+       }
+       found := false
+       for _, field := range schema.Fields {
+               if field.Name == d.unparsedIdentifier {
+                       found = true
+                       break
+               }
+       }
+       // TODO: return proper pyspark error
+       if !found {
+               return nil, sparkerrors.WithType(sparkerrors.InvalidPlanError,
+                       fmt.Errorf("cannot resolve column %s", 
d.unparsedIdentifier))
+       }
+
+       expr := newProtoExpression()
+       id := d.df.PlanId()
+       expr.ExprType = &proto.Expression_UnresolvedAttribute_{
+               UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
+                       UnparsedIdentifier: d.unparsedIdentifier,
+                       PlanId:             &id,
+               },
+       }
+       return expr, nil
+}
+
+type sortExpression struct {
+       child        expression
+       direction    proto.Expression_SortOrder_SortDirection
+       nullOrdering proto.Expression_SortOrder_NullOrdering
+}
+
+func (s *sortExpression) DebugString() string {
+       return s.child.DebugString()
+}
+
+func (s *sortExpression) ToProto(ctx context.Context) (*proto.Expression, 
error) {
+       exp := newProtoExpression()
+       child, err := s.child.ToProto(ctx)
+       if err != nil {
+               return nil, err
+       }
+       exp.ExprType = &proto.Expression_SortOrder_{
+               SortOrder: &proto.Expression_SortOrder{
+                       Child:        child,
+                       Direction:    s.direction,
+                       NullOrdering: s.nullOrdering,
+               },
+       }
+       return exp, nil
+}
+
 type caseWhenExpression struct {
        branches []*caseWhenBranch
-       elseExpr Expression
+       elseExpr expression
 }
 
 type caseWhenBranch struct {
-       condition Expression
-       value     Expression
+       condition expression
+       value     expression
 }
 
-func NewCaseWhenExpression(branches []*caseWhenBranch, elseExpr Expression) 
Expression {
+func NewCaseWhenExpression(branches []*caseWhenBranch, elseExpr expression) 
expression {
        return &caseWhenExpression{branches: branches, elseExpr: elseExpr}
 }
 
@@ -63,8 +129,8 @@ func (c *caseWhenExpression) DebugString() string {
        return fmt.Sprintf("CASE %s %s END", strings.Join(branches, " "), 
elseExpr)
 }
 
-func (c *caseWhenExpression) ToPlan() (*proto.Expression, error) {
-       args := make([]Expression, 0)
+func (c *caseWhenExpression) ToProto(ctx context.Context) (*proto.Expression, 
error) {
+       args := make([]expression, 0)
        for _, branch := range c.branches {
                args = append(args, branch.condition)
                args = append(args, branch.value)
@@ -75,12 +141,12 @@ func (c *caseWhenExpression) ToPlan() (*proto.Expression, 
error) {
        }
 
        fun := NewUnresolvedFunction("when", args, false)
-       return fun.ToPlan()
+       return fun.ToProto(ctx)
 }
 
 type unresolvedFunction struct {
        name       string
-       args       []Expression
+       args       []expression
        isDistinct bool
 }
 
@@ -98,13 +164,13 @@ func (u *unresolvedFunction) DebugString() string {
        return fmt.Sprintf("%s(%s%s)", u.name, distinct, strings.Join(args, ", 
"))
 }
 
-func (u *unresolvedFunction) ToPlan() (*proto.Expression, error) {
-       // Convert input args to the proto Expression.
+func (u *unresolvedFunction) ToProto(ctx context.Context) (*proto.Expression, 
error) {
+       // Convert input args to the proto expression.
        var args []*proto.Expression = nil
        if len(u.args) > 0 {
                args = make([]*proto.Expression, 0)
                for _, arg := range u.args {
-                       p, e := arg.ToPlan()
+                       p, e := arg.ToProto(ctx)
                        if e != nil {
                                return nil, e
                        }
@@ -123,29 +189,29 @@ func (u *unresolvedFunction) ToPlan() (*proto.Expression, 
error) {
        return expr, nil
 }
 
-func NewUnresolvedFunction(name string, args []Expression, isDistinct bool) 
Expression {
+func NewUnresolvedFunction(name string, args []expression, isDistinct bool) 
expression {
        return &unresolvedFunction{name: name, args: args, isDistinct: 
isDistinct}
 }
 
-func NewUnresolvedFunctionWithColumns(name string, cols ...Column) Expression {
-       exprs := make([]Expression, 0)
+func NewUnresolvedFunctionWithColumns(name string, cols ...Column) expression {
+       exprs := make([]expression, 0)
        for _, col := range cols {
-               exprs = append(exprs, col.Expr)
+               exprs = append(exprs, col.expr)
        }
        return NewUnresolvedFunction(name, exprs, false)
 }
 
 type columnAlias struct {
        alias    []string
-       expr     Expression
+       expr     expression
        metadata *string
 }
 
-func NewColumnAlias(alias string, expr Expression) Expression {
+func NewColumnAlias(alias string, expr expression) expression {
        return &columnAlias{alias: []string{alias}, expr: expr}
 }
 
-func NewColumnAliasFromNameParts(alias []string, expr Expression) Expression {
+func NewColumnAliasFromNameParts(alias []string, expr expression) expression {
        return &columnAlias{alias: alias, expr: expr}
 }
 
@@ -155,9 +221,9 @@ func (c *columnAlias) DebugString() string {
        return fmt.Sprintf("%s AS %s", child, alias)
 }
 
-func (c *columnAlias) ToPlan() (*proto.Expression, error) {
+func (c *columnAlias) ToProto(ctx context.Context) (*proto.Expression, error) {
        expr := newProtoExpression()
-       alias, err := c.expr.ToPlan()
+       alias, err := c.expr.ToProto(ctx)
        if err != nil {
                return nil, err
        }
@@ -176,11 +242,11 @@ type columnReference struct {
        planId             *int64
 }
 
-func NewColumnReference(unparsedIdentifier string) Expression {
+func NewColumnReference(unparsedIdentifier string) expression {
        return &columnReference{unparsedIdentifier: unparsedIdentifier}
 }
 
-func NewColumnReferenceWithPlanId(unparsedIdentifier string, planId int64) 
Expression {
+func NewColumnReferenceWithPlanId(unparsedIdentifier string, planId int64) 
expression {
        return &columnReference{unparsedIdentifier: unparsedIdentifier, planId: 
&planId}
 }
 
@@ -188,7 +254,7 @@ func (c *columnReference) DebugString() string {
        return c.unparsedIdentifier
 }
 
-func (c *columnReference) ToPlan() (*proto.Expression, error) {
+func (c *columnReference) ToProto(context.Context) (*proto.Expression, error) {
        expr := newProtoExpression()
        expr.ExprType = &proto.Expression_UnresolvedAttribute_{
                UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
@@ -203,7 +269,7 @@ type sqlExression struct {
        expression_string string
 }
 
-func NewSQLExpression(expression string) Expression {
+func NewSQLExpression(expression string) expression {
        return &sqlExression{expression_string: expression}
 }
 
@@ -211,7 +277,7 @@ func (s *sqlExression) DebugString() string {
        return s.expression_string
 }
 
-func (s *sqlExression) ToPlan() (*proto.Expression, error) {
+func (s *sqlExression) ToProto(context.Context) (*proto.Expression, error) {
        expr := newProtoExpression()
        expr.ExprType = &proto.Expression_ExpressionString_{
                ExpressionString: &proto.Expression_ExpressionString{
@@ -229,7 +295,7 @@ func (l *literalExpression) DebugString() string {
        return fmt.Sprintf("%v", l.value)
 }
 
-func (l *literalExpression) ToPlan() (*proto.Expression, error) {
+func (l *literalExpression) ToProto(context.Context) (*proto.Expression, 
error) {
        expr := newProtoExpression()
        expr.ExprType = &proto.Expression_Literal_{
                Literal: &proto.Expression_Literal{},
@@ -268,6 +334,6 @@ func (l *literalExpression) ToPlan() (*proto.Expression, 
error) {
        return expr, nil
 }
 
-func NewLiteral(value any) Expression {
+func NewLiteral(value any) expression {
        return &literalExpression{value: value}
 }
diff --git a/spark/sql/column/expressions_test.go 
b/spark/sql/column/expressions_test.go
index 7c75c91..5cf44da 100644
--- a/spark/sql/column/expressions_test.go
+++ b/spark/sql/column/expressions_test.go
@@ -16,6 +16,7 @@
 package column
 
 import (
+       "context"
        "reflect"
        "testing"
 
@@ -25,10 +26,10 @@ import (
 
 func TestNewUnresolvedFunction(t *testing.T) {
        colRef := NewColumnReference("martin")
-       colRefPlan, _ := colRef.ToPlan()
+       colRefPlan, _ := colRef.ToProto(context.Background())
        type args struct {
                name       string
-               arguments  []Expression
+               arguments  []expression
                isDistinct bool
        }
        tests := []struct {
@@ -56,7 +57,7 @@ func TestNewUnresolvedFunction(t *testing.T) {
                        name: "TestNewUnresolvedWithArguments",
                        args: args{
                                name:       "id",
-                               arguments:  []Expression{colRef},
+                               arguments:  []expression{colRef},
                                isDistinct: false,
                        },
                        want: &proto.Expression{
@@ -74,7 +75,8 @@ func TestNewUnresolvedFunction(t *testing.T) {
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       got, err := NewUnresolvedFunction(tt.args.name, 
tt.args.arguments, tt.args.isDistinct).ToPlan()
+                       got, err := NewUnresolvedFunction(tt.args.name, 
tt.args.arguments,
+                               
tt.args.isDistinct).ToProto(context.Background())
                        assert.NoError(t, err)
                        if !reflect.DeepEqual(got, tt.want) {
                                assert.Equal(t, tt.want, got)
@@ -86,7 +88,7 @@ func TestNewUnresolvedFunction(t *testing.T) {
 
 func TestNewUnresolvedFunctionWithColumns(t *testing.T) {
        colRef := NewColumn(NewColumnReference("martin"))
-       colRefPlan, _ := colRef.ToPlan()
+       colRefPlan, _ := colRef.ToProto(context.Background())
 
        type args struct {
                name      string
@@ -153,7 +155,8 @@ func TestNewUnresolvedFunctionWithColumns(t *testing.T) {
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       got, err := 
NewUnresolvedFunctionWithColumns(tt.args.name, tt.args.arguments...).ToPlan()
+                       got, err := 
NewUnresolvedFunctionWithColumns(tt.args.name,
+                               
tt.args.arguments...).ToProto(context.Background())
                        assert.NoError(t, err)
                        if !reflect.DeepEqual(got, tt.want) {
                                assert.Equal(t, tt.want, got)
@@ -193,9 +196,9 @@ func TestNewSQLExpression(t *testing.T) {
 
 func TestColumnAlias_Basic(t *testing.T) {
        colRef := NewColumnReference("column")
-       colRefPlan, _ := colRef.ToPlan()
+       colRefPlan, _ := colRef.ToProto(context.Background())
        colAlias := NewColumnAlias("martin", colRef)
-       colAliasPlan, _ := colAlias.ToPlan()
+       colAliasPlan, _ := colAlias.ToProto(context.Background())
        assert.Equal(t, colRefPlan, colAliasPlan.GetAlias().GetExpr())
 
        // Test the debug string:
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index eb285db..bf44ab1 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -21,7 +21,6 @@ import (
        "fmt"
 
        "github.com/apache/spark-connect-go/v35/spark/sql/column"
-       "github.com/apache/spark-connect-go/v35/spark/sql/functions"
 
        "github.com/apache/spark-connect-go/v35/spark/sql/types"
 
@@ -37,6 +36,8 @@ type ResultCollector interface {
 
 // DataFrame is a wrapper for data frame, representing a distributed 
collection of data row.
 type DataFrame interface {
+       // PlanId returns the plan id of the data frame.
+       PlanId() int64
        // WriteResult streams the data frames to a result collector
        WriteResult(ctx context.Context, collector ResultCollector, numRows 
int, truncate bool) error
        // Show uses WriteResult to write the data frames to the console output.
@@ -53,29 +54,21 @@ type DataFrame interface {
        // CreateTempView creates or replaces a temporary view.
        CreateTempView(ctx context.Context, viewName string, replace, global 
bool) error
        // Repartition re-partitions a data frame.
-       Repartition(numPartitions int, columns []string) (DataFrame, error)
+       Repartition(ctx context.Context, numPartitions int, columns []string) 
(DataFrame, error)
        // RepartitionByRange re-partitions a data frame by range partition.
-       RepartitionByRange(numPartitions int, columns []RangePartitionColumn) 
(DataFrame, error)
+       RepartitionByRange(ctx context.Context, numPartitions int, columns 
...column.Convertible) (DataFrame, error)
        // Filter filters the data frame by a column condition.
-       Filter(condition column.Column) (DataFrame, error)
+       Filter(ctx context.Context, condition column.Convertible) (DataFrame, 
error)
        // FilterByString filters the data frame by a string condition.
-       FilterByString(condition string) (DataFrame, error)
-       // Col returns a column by name.
-       Col(name string) (column.Column, error)
-
+       FilterByString(ctx context.Context, condition string) (DataFrame, error)
        // Select projects a list of columns from the DataFrame
-       Select(columns ...column.Column) (DataFrame, error)
+       Select(ctx context.Context, columns ...column.Convertible) (DataFrame, 
error)
        // SelectExpr projects a list of columns from the DataFrame by string 
expressions
-       SelectExpr(exprs ...string) (DataFrame, error)
+       SelectExpr(ctx context.Context, exprs ...string) (DataFrame, error)
        // Alias creates a new DataFrame with the specified subquery alias
-       Alias(alias string) DataFrame
+       Alias(ctx context.Context, alias string) DataFrame
        // CrossJoin joins the current DataFrame with another DataFrame using 
the cross product
-       CrossJoin(other DataFrame) DataFrame
-}
-
-type RangePartitionColumn struct {
-       Name       string
-       Descending bool
+       CrossJoin(ctx context.Context, other DataFrame) DataFrame
 }
 
 // dataFrameImpl is an implementation of DataFrame interface.
@@ -84,11 +77,15 @@ type dataFrameImpl struct {
        relation *proto.Relation // TODO change to proto.Plan?
 }
 
-func (df *dataFrameImpl) SelectExpr(exprs ...string) (DataFrame, error) {
+func (df *dataFrameImpl) PlanId() int64 {
+       return df.relation.GetCommon().GetPlanId()
+}
+
+func (df *dataFrameImpl) SelectExpr(ctx context.Context, exprs ...string) 
(DataFrame, error) {
        expressions := make([]*proto.Expression, 0, len(exprs))
        for _, expr := range exprs {
-               col := functions.Expr(expr)
-               f, e := col.ToPlan()
+               col := column.NewSQLExpression(expr)
+               f, e := col.ToProto(ctx)
                if e != nil {
                        return nil, e
                }
@@ -109,7 +106,7 @@ func (df *dataFrameImpl) SelectExpr(exprs ...string) 
(DataFrame, error) {
        return NewDataFrame(df.session, rel), nil
 }
 
-func (df *dataFrameImpl) Alias(alias string) DataFrame {
+func (df *dataFrameImpl) Alias(ctx context.Context, alias string) DataFrame {
        rel := &proto.Relation{
                Common: &proto.RelationCommon{
                        PlanId: newPlanId(),
@@ -124,7 +121,7 @@ func (df *dataFrameImpl) Alias(alias string) DataFrame {
        return NewDataFrame(df.session, rel)
 }
 
-func (df *dataFrameImpl) CrossJoin(other DataFrame) DataFrame {
+func (df *dataFrameImpl) CrossJoin(ctx context.Context, other DataFrame) 
DataFrame {
        otherDf := other.(*dataFrameImpl)
        rel := &proto.Relation{
                Common: &proto.RelationCommon{
@@ -287,7 +284,7 @@ func (df *dataFrameImpl) CreateTempView(ctx 
context.Context, viewName string, re
        return err
 }
 
-func (df *dataFrameImpl) Repartition(numPartitions int, columns []string) 
(DataFrame, error) {
+func (df *dataFrameImpl) Repartition(ctx context.Context, numPartitions int, 
columns []string) (DataFrame, error) {
        var partitionExpressions []*proto.Expression
        if columns != nil {
                partitionExpressions = make([]*proto.Expression, 0, 
len(columns))
@@ -305,31 +302,16 @@ func (df *dataFrameImpl) Repartition(numPartitions int, 
columns []string) (DataF
        return df.repartitionByExpressions(numPartitions, partitionExpressions)
 }
 
-func (df *dataFrameImpl) RepartitionByRange(numPartitions int, columns 
[]RangePartitionColumn) (DataFrame, error) {
+func (df *dataFrameImpl) RepartitionByRange(ctx context.Context, numPartitions 
int, columns ...column.Convertible) (DataFrame, error) {
        var partitionExpressions []*proto.Expression
        if columns != nil {
                partitionExpressions = make([]*proto.Expression, 0, 
len(columns))
                for _, c := range columns {
-                       columnExpr := &proto.Expression{
-                               ExprType: 
&proto.Expression_UnresolvedAttribute_{
-                                       UnresolvedAttribute: 
&proto.Expression_UnresolvedAttribute{
-                                               UnparsedIdentifier: c.Name,
-                                       },
-                               },
-                       }
-                       direction := 
proto.Expression_SortOrder_SORT_DIRECTION_ASCENDING
-                       if c.Descending {
-                               direction = 
proto.Expression_SortOrder_SORT_DIRECTION_DESCENDING
-                       }
-                       sortExpr := &proto.Expression{
-                               ExprType: &proto.Expression_SortOrder_{
-                                       SortOrder: &proto.Expression_SortOrder{
-                                               Child:     columnExpr,
-                                               Direction: direction,
-                                       },
-                               },
+                       expr, err := c.ToProto(ctx)
+                       if err != nil {
+                               return nil, err
                        }
-                       partitionExpressions = append(partitionExpressions, 
sortExpr)
+                       partitionExpressions = append(partitionExpressions, 
expr)
                }
        }
        return df.repartitionByExpressions(numPartitions, partitionExpressions)
@@ -367,8 +349,8 @@ func (df *dataFrameImpl) 
repartitionByExpressions(numPartitions int,
        return NewDataFrame(df.session, newRelation), nil
 }
 
-func (df *dataFrameImpl) Filter(condition column.Column) (DataFrame, error) {
-       cnd, err := condition.ToPlan()
+func (df *dataFrameImpl) Filter(ctx context.Context, condition 
column.Convertible) (DataFrame, error) {
+       cnd, err := condition.ToProto(ctx)
        if err != nil {
                return nil, err
        }
@@ -387,19 +369,14 @@ func (df *dataFrameImpl) Filter(condition column.Column) 
(DataFrame, error) {
        return NewDataFrame(df.session, rel), nil
 }
 
-func (df *dataFrameImpl) FilterByString(condition string) (DataFrame, error) {
-       return df.Filter(functions.Expr(condition))
-}
-
-func (df *dataFrameImpl) Col(name string) (column.Column, error) {
-       planId := df.relation.Common.GetPlanId()
-       return column.NewColumn(column.NewColumnReferenceWithPlanId(name, 
planId)), nil
+func (df *dataFrameImpl) FilterByString(ctx context.Context, condition string) 
(DataFrame, error) {
+       return df.Filter(ctx, 
column.NewColumn(column.NewSQLExpression(condition)))
 }
 
-func (df *dataFrameImpl) Select(columns ...column.Column) (DataFrame, error) {
+func (df *dataFrameImpl) Select(ctx context.Context, columns 
...column.Convertible) (DataFrame, error) {
        exprs := make([]*proto.Expression, 0, len(columns))
        for _, c := range columns {
-               expr, err := c.ToPlan()
+               expr, err := c.ToProto(ctx)
                if err != nil {
                        return nil, err
                }
diff --git a/spark/sql/functions/buiitins.go b/spark/sql/functions/buiitins.go
index ed1b3d1..a2a7bf8 100644
--- a/spark/sql/functions/buiitins.go
+++ b/spark/sql/functions/buiitins.go
@@ -15,7 +15,9 @@
 
 package functions
 
-import "github.com/apache/spark-connect-go/v35/spark/sql/column"
+import (
+       "github.com/apache/spark-connect-go/v35/spark/sql/column"
+)
 
 func Expr(expr string) column.Column {
        return column.NewColumn(column.NewSQLExpression(expr))
diff --git a/spark/sql/functions/generated.go b/spark/sql/functions/generated.go
index 50b17b4..b468cda 100644
--- a/spark/sql/functions/generated.go
+++ b/spark/sql/functions/generated.go
@@ -15,7 +15,9 @@
 
 package functions
 
-import "github.com/apache/spark-connect-go/v35/spark/sql/column"
+import (
+       "github.com/apache/spark-connect-go/v35/spark/sql/column"
+)
 
 // BitwiseNOT - Computes bitwise not.
 //
diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go
index 0ef785d..a78681e 100644
--- a/spark/sql/sparksession_test.go
+++ b/spark/sql/sparksession_test.go
@@ -183,7 +183,7 @@ func TestWriteResultStreamsArrowResultToCollector(t 
*testing.T) {
        resp, err := session.Sql(ctx, query)
        assert.NoError(t, err)
        assert.NotNil(t, resp)
-       df, err := resp.Repartition(1, []string{"1"})
+       df, err := resp.Repartition(ctx, 1, []string{"1"})
        assert.NoError(t, err)
        rows, err := df.Collect(ctx)
        assert.NoError(t, err)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to