This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new e93f73b [#58] Adding Support for select, selectExpr, crossJoin, and
Alias
e93f73b is described below
commit e93f73b471f6949f975db02b986f917c1b19f064
Author: Martin Grund <[email protected]>
AuthorDate: Mon Aug 26 10:07:11 2024 +0200
[#58] Adding Support for select, selectExpr, crossJoin, and Alias
### What changes were proposed in this pull request?
Adds support in the DataFrame for:
* `Alias`
* `Select`
* `SelectExpr`
* `CrossJoin`
### Why are the changes needed?
Compatibility
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
Added new integration tests.
Closes #59 from grundprinzip/df_functions_v1.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/dataframe_test.go | 93 ++++++++++++++++++++++++++++
spark/sql/dataframe.go | 90 +++++++++++++++++++++++++++
2 files changed, 183 insertions(+)
diff --git a/internal/tests/integration/dataframe_test.go
b/internal/tests/integration/dataframe_test.go
new file mode 100644
index 0000000..b6fa117
--- /dev/null
+++ b/internal/tests/integration/dataframe_test.go
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration
+
+import (
+ "context"
+ "testing"
+
+ "github.com/apache/spark-connect-go/v35/spark/sql"
+ "github.com/apache/spark-connect-go/v35/spark/sql/functions"
+ "github.com/stretchr/testify/assert"
+)
+
+// TestDataFrame_Select checks that Select projects the given columns:
+// selecting two literals over range(100) must yield 100 rows holding two
+// values each. The test connects to a Spark Connect endpoint at
+// sc://localhost.
+func TestDataFrame_Select(t *testing.T) {
+	ctx := context.Background()
+	spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+	// Stop at the first failed step: the following steps use the result of
+	// the failed one and could panic instead of reporting a clean failure.
+	if !assert.NoError(t, err) {
+		return
+	}
+	df, err := spark.Sql(ctx, "select * from range(100)")
+	if !assert.NoError(t, err) {
+		return
+	}
+	df, err = df.Select(functions.Lit("1"), functions.Lit("2"))
+	if !assert.NoError(t, err) {
+		return
+	}
+
+	res, err := df.Collect(ctx)
+	if !assert.NoError(t, err) {
+		return
+	}
+	// Guards the res[0] access below against an index-out-of-range panic.
+	if !assert.Equal(t, 100, len(res)) {
+		return
+	}
+
+	rowZero := res[0]
+	vals, err := rowZero.Values()
+	assert.NoError(t, err)
+	assert.Equal(t, 2, len(vals))
+}
+
+// TestDataFrame_SelectExpr checks that SelectExpr projects string
+// expressions: three expressions over range(100) must yield 100 rows holding
+// three values each. The test connects to a Spark Connect endpoint at
+// sc://localhost.
+func TestDataFrame_SelectExpr(t *testing.T) {
+	ctx := context.Background()
+	spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+	// Stop at the first failed step: the following steps use the result of
+	// the failed one and could panic instead of reporting a clean failure.
+	if !assert.NoError(t, err) {
+		return
+	}
+	df, err := spark.Sql(ctx, "select * from range(100)")
+	if !assert.NoError(t, err) {
+		return
+	}
+	df, err = df.SelectExpr("1", "2", "spark_partition_id()")
+	if !assert.NoError(t, err) {
+		return
+	}
+
+	res, err := df.Collect(ctx)
+	if !assert.NoError(t, err) {
+		return
+	}
+	// Guards the res[0] access below against an index-out-of-range panic.
+	if !assert.Equal(t, 100, len(res)) {
+		return
+	}
+
+	rowZero := res[0]
+	vals, err := rowZero.Values()
+	assert.NoError(t, err)
+	assert.Equal(t, 3, len(vals))
+}
+
+// TestDataFrame_Alias checks that an aliased DataFrame can still be
+// collected and returns all 100 rows of range(100). The test connects to a
+// Spark Connect endpoint at sc://localhost.
+func TestDataFrame_Alias(t *testing.T) {
+	ctx := context.Background()
+	spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+	// Stop at the first failed step: the following steps use the result of
+	// the failed one and could panic instead of reporting a clean failure.
+	if !assert.NoError(t, err) {
+		return
+	}
+	df, err := spark.Sql(ctx, "select * from range(100)")
+	if !assert.NoError(t, err) {
+		return
+	}
+	df = df.Alias("df")
+	res, err := df.Collect(ctx)
+	assert.NoError(t, err)
+	assert.Equal(t, 100, len(res))
+}
+
+// TestDataFrame_CrossJoin checks that cross-joining two range(10)
+// DataFrames yields 10 x 10 = 100 rows with two values per row. The test
+// connects to a Spark Connect endpoint at sc://localhost.
+func TestDataFrame_CrossJoin(t *testing.T) {
+	ctx := context.Background()
+	spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+	// Stop at the first failed step: the following steps use the result of
+	// the failed one and could panic instead of reporting a clean failure.
+	if !assert.NoError(t, err) {
+		return
+	}
+	df1, err := spark.Sql(ctx, "select * from range(10)")
+	if !assert.NoError(t, err) {
+		return
+	}
+	df2, err := spark.Sql(ctx, "select * from range(10)")
+	if !assert.NoError(t, err) {
+		return
+	}
+	df := df1.CrossJoin(df2)
+	res, err := df.Collect(ctx)
+	if !assert.NoError(t, err) {
+		return
+	}
+	// Guards the res[0] access below against an index-out-of-range panic.
+	if !assert.Equal(t, 100, len(res)) {
+		return
+	}
+
+	vals, err := res[0].Values()
+	assert.NoError(t, err)
+	assert.Equal(t, 2, len(vals))
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index cd8ba64..eb285db 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -62,6 +62,15 @@ type DataFrame interface {
FilterByString(condition string) (DataFrame, error)
// Col returns a column by name.
Col(name string) (column.Column, error)
+
+ // Select projects a list of columns from the DataFrame
+ Select(columns ...column.Column) (DataFrame, error)
+ // SelectExpr projects a list of columns from the DataFrame by string expressions
+ SelectExpr(exprs ...string) (DataFrame, error)
+ // Alias creates a new DataFrame with the specified subquery alias
+ Alias(alias string) DataFrame
+ // CrossJoin joins the current DataFrame with another DataFrame using the cross product
+ CrossJoin(other DataFrame) DataFrame
}
type RangePartitionColumn struct {
@@ -75,6 +84,63 @@ type dataFrameImpl struct {
relation *proto.Relation // TODO change to proto.Plan?
}
+// SelectExpr builds a new DataFrame that projects this DataFrame's relation
+// onto the given expression strings. Each string is turned into a column via
+// functions.Expr and converted to its plan expression; the first conversion
+// error is returned as-is together with a nil DataFrame.
+func (df *dataFrameImpl) SelectExpr(exprs ...string) (DataFrame, error) {
+	plans := make([]*proto.Expression, len(exprs))
+	for i := range exprs {
+		parsed := functions.Expr(exprs[i])
+		plan, err := parsed.ToPlan()
+		if err != nil {
+			return nil, err
+		}
+		plans[i] = plan
+	}
+
+	// Wrap the current relation in a Project node under its own plan id.
+	projection := &proto.Relation{
+		Common: &proto.RelationCommon{PlanId: newPlanId()},
+		RelType: &proto.Relation_Project{
+			Project: &proto.Project{
+				Input:       df.relation,
+				Expressions: plans,
+			},
+		},
+	}
+	return NewDataFrame(df.session, projection), nil
+}
+
+// Alias returns a new DataFrame whose relation is a SubqueryAlias node that
+// wraps this DataFrame's relation and carries the given alias. The new
+// relation gets its own plan id from newPlanId.
+func (df *dataFrameImpl) Alias(alias string) DataFrame {
+	aliased := &proto.Relation{
+		Common: &proto.RelationCommon{PlanId: newPlanId()},
+		RelType: &proto.Relation_SubqueryAlias{
+			SubqueryAlias: &proto.SubqueryAlias{
+				Input: df.relation,
+				Alias: alias,
+			},
+		},
+	}
+	return NewDataFrame(df.session, aliased)
+}
+
+// CrossJoin returns a new DataFrame whose relation is a Join node of type
+// JOIN_TYPE_CROSS with this DataFrame's relation on the left and other's
+// relation on the right.
+//
+// other must be a *dataFrameImpl: the type assertion below is unchecked and
+// panics for any other DataFrame implementation (or a nil interface value).
+func (df *dataFrameImpl) CrossJoin(other DataFrame) DataFrame {
+	rhs := other.(*dataFrameImpl)
+	joined := &proto.Relation{
+		Common: &proto.RelationCommon{PlanId: newPlanId()},
+		RelType: &proto.Relation_Join{
+			Join: &proto.Join{
+				Left:     df.relation,
+				Right:    rhs.relation,
+				JoinType: proto.Join_JOIN_TYPE_CROSS,
+			},
+		},
+	}
+	return NewDataFrame(df.session, joined)
+}
+
// NewDataFrame creates a new DataFrame
func NewDataFrame(session *sparkSessionImpl, relation *proto.Relation)
DataFrame {
return &dataFrameImpl{
@@ -329,3 +395,27 @@ func (df *dataFrameImpl) Col(name string) (column.Column,
error) {
planId := df.relation.Common.GetPlanId()
return column.NewColumn(column.NewColumnReferenceWithPlanId(name,
planId)), nil
}
+
+// Select builds a new DataFrame that projects this DataFrame's relation onto
+// the given columns. Every column is converted to its plan expression; the
+// first conversion error is returned as-is together with a nil DataFrame.
+func (df *dataFrameImpl) Select(columns ...column.Column) (DataFrame, error) {
+	plans := make([]*proto.Expression, len(columns))
+	for i := range columns {
+		c := columns[i]
+		plan, err := c.ToPlan()
+		if err != nil {
+			return nil, err
+		}
+		plans[i] = plan
+	}
+
+	// Wrap the current relation in a Project node under its own plan id.
+	projection := &proto.Relation{
+		Common: &proto.RelationCommon{PlanId: newPlanId()},
+		RelType: &proto.Relation_Project{
+			Project: &proto.Project{
+				Input:       df.relation,
+				Expressions: plans,
+			},
+		},
+	}
+	return NewDataFrame(df.session, projection), nil
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]