This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new a8ea83b  [#58] Initial Support for `GroupBy`
a8ea83b is described below

commit a8ea83bc0005a57da6cdbeb75ddd604828443af1
Author: Martin Grund <[email protected]>
AuthorDate: Wed Aug 28 21:18:55 2024 +0200

    [#58] Initial Support for `GroupBy`
    
    ### What changes were proposed in this pull request?
    This patch adds the initial version of `GroupBy`, which can be used to
    run aggregate queries. It allows running the following code:
    
    ```
    df, _ := spark.Sql(ctx, "select * from range(10)")
    df, _ = df.GroupBy(functions.Col("id")).Min(ctx) // Min per group
    ```
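
    `Agg` accepts explicit aggregate expressions as well. A minimal sketch
    based on the API added in this patch (the column names are purely
    illustrative):

    ```
    df, _ := spark.Sql(ctx, "select 'a' as a, 1 as b from range(10)")
    df, _ = df.GroupBy(functions.Col("a")).Agg(ctx,
        functions.Sum(functions.Col("b")).Alias("sum_b"))
    ```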
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds `GroupBy` support.
    
    ### How was this patch tested?
    Unit and Integration tests.
    
    Closes #60 from grundprinzip/df_group_by.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go       |  19 +++
 .../tests/integration/helper.go                    |  20 ++-
 spark/sparkerrors/errors.go                        |   4 +
 spark/sql/column/column.go                         |   4 +
 spark/sql/column/column_test.go                    |  14 ++
 spark/sql/dataframe.go                             |  13 ++
 spark/sql/dataframe_test.go                        |  32 ++++
 spark/sql/functions/generated.go                   |   7 +-
 spark/sql/group.go                                 | 174 +++++++++++++++++++++
 spark/sql/group_test.go                            |  99 ++++++++++++
 spark/sql/types/datatype.go                        |  57 +++++++
 11 files changed, 437 insertions(+), 6 deletions(-)

diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index 9fe2e3d..576422b 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -99,3 +99,22 @@ func TestDataFrame_CrossJoin(t *testing.T) {
        assert.NoError(t, e)
        assert.Equal(t, 2, len(v))
 }
+
+func TestDataFrame_GroupBy(t *testing.T) {
+       ctx, spark := connect()
+       src, _ := spark.Sql(ctx, "select 'a' as a, 1 as b from range(10)")
+       df, _ := src.GroupBy(functions.Col("a")).Agg(ctx, functions.Sum(functions.Col("b")))
+
+       res, err := df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 1, len(res))
+
+       df, err = src.GroupBy(functions.Col("a")).Count(ctx)
+       assert.NoError(t, err)
+       res, err = df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 1, len(res))
+       vals, _ := res[0].Values()
+       assert.Equal(t, "a", vals[0])
+       assert.Equal(t, int64(10), vals[1])
+}
diff --git a/spark/sql/dataframe_test.go b/internal/tests/integration/helper.go
similarity index 67%
copy from spark/sql/dataframe_test.go
copy to internal/tests/integration/helper.go
index 831ad94..a1f7f28 100644
--- a/spark/sql/dataframe_test.go
+++ b/internal/tests/integration/helper.go
@@ -1,4 +1,3 @@
-//
 // Licensed to the Apache Software Foundation (ASF) under one or more
 // contributor license agreements.  See the NOTICE file distributed with
 // this work for additional information regarding copyright ownership.
@@ -6,7 +5,7 @@
 // (the "License"); you may not use this file except in compliance with
 // the License.  You may obtain a copy of the License at
 //
-//    http://www.apache.org/licenses/LICENSE-2.0
+//     http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,4 +13,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package sql
+package integration
+
+import (
+       "context"
+
+       "github.com/apache/spark-connect-go/v35/spark/sql"
+)
+
+func connect() (context.Context, sql.SparkSession) {
+       ctx := context.Background()
+       spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+       if err != nil {
+               panic(err)
+       }
+       return ctx, spark
+}
diff --git a/spark/sparkerrors/errors.go b/spark/sparkerrors/errors.go
index c2665c8..5a13aa1 100644
--- a/spark/sparkerrors/errors.go
+++ b/spark/sparkerrors/errors.go
@@ -49,6 +49,10 @@ func WithString(err error, errMsg string) error {
        return &wrappedError{cause: errors.Wrap(err, 1), errorType: errors.New(errMsg)}
 }
 
+func WithStringf(err error, errMsg string, params ...any) error {
+       return &wrappedError{cause: errors.Wrap(err, 1), errorType: fmt.Errorf(errMsg, params...)}
+}
+
 type errorType error
 
 var (
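
For reference, a minimal usage sketch of the new `WithStringf` helper,
mirroring its call site in `spark/sql/group.go` below (the column names are
illustrative):

```
err := sparkerrors.WithStringf(sparkerrors.InvalidInputError,
        "columns %v are not numeric", []string{"a", "b"})
```
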
diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go
index aae8fd1..28bf197 100644
--- a/spark/sql/column/column.go
+++ b/spark/sql/column/column.go
@@ -85,6 +85,10 @@ func (c Column) Asc() Column {
        })
 }
 
+func (c Column) Alias(alias string) Column {
+       return NewColumn(NewColumnAlias(alias, c.expr))
+}
+
 func NewColumn(expr expression) Column {
        return Column{
                expr: expr,
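
A quick sketch of the new `Column.Alias`, as exercised by `GroupedData.Count`
later in this patch:

```
cnt := functions.Count(functions.Lit(1)).Alias("count")
```
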
diff --git a/spark/sql/column/column_test.go b/spark/sql/column/column_test.go
index b823921..ba66a93 100644
--- a/spark/sql/column/column_test.go
+++ b/spark/sql/column/column_test.go
@@ -92,11 +92,25 @@ func TestColumnFunctions(t *testing.T) {
        col1 := NewColumn(NewColumnReference("col1"))
        col2 := NewColumn(NewColumnReference("col2"))
 
+       col1Plan, _ := col1.ToProto(context.Background())
+
        tests := []struct {
                name string
                arg  Column
                want *proto.Expression
        }{
+               {
+                       name: "TestColumnAlias",
+                       arg:  NewColumn(NewColumnAlias("alias", col1.expr)),
+                       want: &proto.Expression{
+                               ExprType: &proto.Expression_Alias_{
+                                       Alias: &proto.Expression_Alias{
+                                               Expr: col1Plan,
+                                               Name: []string{"alias"},
+                                       },
+                               },
+                       },
+               },
                {
                        name: "TestNewUnresolvedFunction",
                        arg:  NewColumn(NewUnresolvedFunction("id", nil, false)),
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index bf44ab1..63215ab 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -69,6 +69,9 @@ type DataFrame interface {
        Alias(ctx context.Context, alias string) DataFrame
        // CrossJoin joins the current DataFrame with another DataFrame using the cross product
        CrossJoin(ctx context.Context, other DataFrame) DataFrame
+       // GroupBy groups the DataFrame by the specified columns so that aggregation
+       // can be performed on them. See GroupedData for all the available aggregate functions.
+       GroupBy(cols ...column.Convertible) *GroupedData
 }
 
 // dataFrameImpl is an implementation of DataFrame interface.
@@ -396,3 +399,13 @@ func (df *dataFrameImpl) Select(ctx context.Context, columns ...column.Convertib
        }
        return NewDataFrame(df.session, rel), nil
 }
+
+// GroupBy groups the DataFrame by the specified columns so that aggregation
+// can be performed on them. See GroupedData for all the available aggregate functions.
+func (df *dataFrameImpl) GroupBy(cols ...column.Convertible) *GroupedData {
+       return &GroupedData{
+               df:           *df,
+               groupingCols: cols,
+               groupType:    "groupby",
+       }
+}
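
Since `GroupBy` is variadic over `column.Convertible`, grouping by multiple
columns works the same way (a sketch with illustrative column names):

```
df, err := df.GroupBy(functions.Col("a"), functions.Col("b")).Count(ctx)
```
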
diff --git a/spark/sql/dataframe_test.go b/spark/sql/dataframe_test.go
index 831ad94..39145b0 100644
--- a/spark/sql/dataframe_test.go
+++ b/spark/sql/dataframe_test.go
@@ -15,3 +15,35 @@
 // limitations under the License.
 
 package sql
+
+import (
+       "context"
+       "testing"
+
+       proto "github.com/apache/spark-connect-go/v35/internal/generated"
+       "github.com/apache/spark-connect-go/v35/spark/sql/functions"
+       "github.com/stretchr/testify/assert"
+)
+
+func TestDataFrameImpl_GroupBy(t *testing.T) {
+       ctx := context.Background()
+       rel := &proto.Relation{
+               RelType: &proto.Relation_Range{
+                       Range: &proto.Range{
+                               End:  10,
+                               Step: 1,
+                       },
+               },
+       }
+       df := NewDataFrame(nil, rel)
+       gd := df.GroupBy(functions.Col("id"))
+       assert.NotNil(t, gd)
+
+       assert.Equal(t, gd.groupType, "groupby")
+
+       df, err := gd.Agg(ctx, functions.Count(functions.Lit(1)))
+       assert.Nil(t, err)
+       impl := df.(*dataFrameImpl)
+       assert.NotNil(t, impl)
+       assert.IsType(t, impl.relation.RelType, &proto.Relation_Aggregate{})
+}
diff --git a/spark/sql/functions/generated.go b/spark/sql/functions/generated.go
index b468cda..cb9421e 100644
--- a/spark/sql/functions/generated.go
+++ b/spark/sql/functions/generated.go
@@ -2764,8 +2764,8 @@ func MakeTimestamp(years column.Column, months column.Column, days column.Column
 func MakeTimestampLtz(years column.Column, months column.Column, days column.Column,
        hours column.Column, mins column.Column, secs column.Column, timezone column.Column,
 ) column.Column {
-       return column.NewColumn(column.NewUnresolvedFunctionWithColumns("make_timestamp_ltz", years,
-               months, days, hours, mins, secs, timezone))
+       return column.NewColumn(column.NewUnresolvedFunctionWithColumns("make_timestamp_ltz",
+               years, months, days, hours, mins, secs, timezone))
 }
 
 // MakeTimestampNtz - Create local date-time from years, months, days, hours, mins, secs fields.
@@ -2776,7 +2776,8 @@ func MakeTimestampLtz(years column.Column, months column.Column, days column.Col
 func MakeTimestampNtz(years column.Column, months column.Column, days column.Column,
        hours column.Column, mins column.Column, secs column.Column,
 ) column.Column {
-       return column.NewColumn(column.NewUnresolvedFunctionWithColumns("make_timestamp_ntz", years, months, days, hours, mins, secs))
+       return column.NewColumn(column.NewUnresolvedFunctionWithColumns("make_timestamp_ntz",
+               years, months, days, hours, mins, secs))
 }
 
 // MakeYmInterval - Make year-month interval from years, months.
diff --git a/spark/sql/group.go b/spark/sql/group.go
new file mode 100644
index 0000000..4734225
--- /dev/null
+++ b/spark/sql/group.go
@@ -0,0 +1,174 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+       "context"
+
+       proto "github.com/apache/spark-connect-go/v35/internal/generated"
+       "github.com/apache/spark-connect-go/v35/spark/sparkerrors"
+       "github.com/apache/spark-connect-go/v35/spark/sql/column"
+       "github.com/apache/spark-connect-go/v35/spark/sql/functions"
+)
+
+type GroupedData struct {
+       df           dataFrameImpl
+       groupType    string
+       groupingCols []column.Convertible
+       pivotValues  []any
+       // groupingSets [][]column.Column
+}
+
+// Agg computes aggregates and returns the result as a DataFrame. The aggregate
+// expressions are passed as column.Column arguments.
+func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Column) (DataFrame, error) {
+       if len(exprs) == 0 {
+               return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "exprs should not be empty")
+       }
+
+       agg := &proto.Aggregate{
+               Input: gd.df.relation,
+       }
+
+       // Add all grouping and aggregate expressions.
+       agg.GroupingExpressions = make([]*proto.Expression, len(gd.groupingCols))
+       for i, col := range gd.groupingCols {
+               exp, err := col.ToProto(ctx)
+               if err != nil {
+                       return nil, err
+               }
+               agg.GroupingExpressions[i] = exp
+       }
+
+       agg.AggregateExpressions = make([]*proto.Expression, len(exprs))
+       for i, expr := range exprs {
+               exp, err := expr.ToProto(ctx)
+               if err != nil {
+                       return nil, err
+               }
+               agg.AggregateExpressions[i] = exp
+       }
+
+       // Apply the groupType
+       switch gd.groupType {
+       case "pivot":
+               agg.GroupType = proto.Aggregate_GROUP_TYPE_PIVOT
+               // Apply all pivot behavior and convert columns into literals.
+               if len(gd.pivotValues) == 0 {
+                       return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivotValues should not be empty")
+               }
+               agg.Pivot = &proto.Aggregate_Pivot{
+                       Values: make([]*proto.Expression_Literal, len(gd.pivotValues)),
+               }
+               for i, v := range gd.pivotValues {
+                       exp, err := column.NewLiteral(v).ToProto(ctx)
+                       if err != nil {
+                               return nil, err
+                       }
+                       agg.Pivot.Values[i] = exp.GetLiteral()
+               }
+       case "groupby":
+               agg.GroupType = proto.Aggregate_GROUP_TYPE_GROUPBY
+       case "rollup":
+               agg.GroupType = proto.Aggregate_GROUP_TYPE_ROLLUP
+       case "cube":
+               agg.GroupType = proto.Aggregate_GROUP_TYPE_CUBE
+       }
+
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Aggregate{
+                       Aggregate: agg,
+               },
+       }
+       return NewDataFrame(gd.df.session, rel), nil
+}
+
+func (gd *GroupedData) numericAgg(ctx context.Context, name string, cols ...string) (DataFrame, error) {
+       schema, err := gd.df.Schema(ctx)
+       if err != nil {
+               return nil, err
+       }
+
+       // Find all numeric cols in the schema:
+       numericCols := make([]string, 0)
+       for _, field := range schema.Fields {
+               if field.DataType.IsNumeric() {
+                       numericCols = append(numericCols, field.Name)
+               }
+       }
+
+       aggCols := cols
+       if len(cols) > 0 {
+               invalidCols := make([]string, 0)
+               for _, col := range cols {
+                       found := false
+                       for _, nc := range numericCols {
+                               if col == nc {
+                                       found = true
+                               }
+                       }
+                       if !found {
+                               invalidCols = append(invalidCols, col)
+                       }
+               }
+               if len(invalidCols) > 0 {
+                       return nil, sparkerrors.WithStringf(sparkerrors.InvalidInputError,
+                               "columns %v are not numeric", invalidCols)
+               }
+       } else {
+               aggCols = numericCols
+       }
+
+       finalColumns := make([]column.Column, len(aggCols))
+       for i, col := range aggCols {
+               finalColumns[i] = column.NewColumn(column.NewUnresolvedFunctionWithColumns(name, functions.Col(col)))
+       }
+       return gd.Agg(ctx, finalColumns...)
+}
+
+// Min computes the min value for each numeric column for each group.
+func (gd *GroupedData) Min(ctx context.Context, cols ...string) (DataFrame, error) {
+       return gd.numericAgg(ctx, "min", cols...)
+}
+
+// Max computes the max value for each numeric column for each group.
+func (gd *GroupedData) Max(ctx context.Context, cols ...string) (DataFrame, error) {
+       return gd.numericAgg(ctx, "max", cols...)
+}
+
+// Avg computes the avg value for each numeric column for each group.
+func (gd *GroupedData) Avg(ctx context.Context, cols ...string) (DataFrame, error) {
+       return gd.numericAgg(ctx, "avg", cols...)
+}
+
+// Sum computes the sum for each numeric column for each group.
+func (gd *GroupedData) Sum(ctx context.Context, cols ...string) (DataFrame, error) {
+       return gd.numericAgg(ctx, "sum", cols...)
+}
+
+// Count computes the number of rows for each group.
+func (gd *GroupedData) Count(ctx context.Context) (DataFrame, error) {
+       return gd.Agg(ctx, functions.Count(functions.Lit(1)).Alias("count"))
+}
+
+// Mean computes the average value for each numeric column for each group.
+func (gd *GroupedData) Mean(ctx context.Context, cols ...string) (DataFrame, error) {
+       return gd.Avg(ctx, cols...)
+}
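
Putting `group.go` together, an end-to-end sketch combining `Agg` with aliased
aggregate expressions (table layout assumed, as in the integration test above):

```
src, _ := spark.Sql(ctx, "select 'a' as a, 1 as b from range(10)")
res, err := src.GroupBy(functions.Col("a")).Agg(ctx,
        functions.Sum(functions.Col("b")).Alias("sum_b"),
        functions.Count(functions.Lit(1)).Alias("cnt"))
```
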
diff --git a/spark/sql/group_test.go b/spark/sql/group_test.go
new file mode 100644
index 0000000..8cdd908
--- /dev/null
+++ b/spark/sql/group_test.go
@@ -0,0 +1,99 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+       "context"
+       "testing"
+
+       proto "github.com/apache/spark-connect-go/v35/internal/generated"
+       "github.com/apache/spark-connect-go/v35/spark/client"
+       "github.com/apache/spark-connect-go/v35/spark/client/testutils"
+       "github.com/apache/spark-connect-go/v35/spark/mocks"
+       "github.com/stretchr/testify/assert"
+)
+
+var sampleDataFrame = dataFrameImpl{session: nil, relation: &proto.Relation{
+       RelType: &proto.Relation_Range{
+               Range: &proto.Range{
+                       End:  10,
+                       Step: 1,
+               },
+       },
+}}
+
+func TestGroupedData_Agg(t *testing.T) {
+       ctx := context.Background()
+       c := client.NewSparkExecutorFromClient(
+               testutils.NewConnectServiceClientMock(nil, mocks.AnalyzePlanResponse, nil, nil), nil, mocks.MockSessionId)
+       session := sparkSessionImpl{sessionId: mocks.MockSessionId, client: c}
+       sampleDataFrame.session = &session
+
+       gd := GroupedData{
+               groupType: "groupby",
+               df:        sampleDataFrame,
+       }
+
+       // Aggregating a non-existing column should fail
+       _, err := gd.Min(ctx, "nonExistingColumn")
+       assert.Error(t, err)
+
+       // Aggregating an existing numeric column should work
+       df, err := gd.Min(ctx, "col0")
+       assert.NoError(t, err)
+       assert.IsType(t, df.(*dataFrameImpl).relation.RelType, &proto.Relation_Aggregate{})
+       assert.Equal(t, "min", df.(*dataFrameImpl).relation.GetAggregate().GetAggregateExpressions()[0].GetUnresolvedFunction().FunctionName)
+
+       // Same for the remaining numeric aggregations
+       df, err = gd.Max(ctx, "col0")
+       assert.NoError(t, err)
+       assert.IsType(t, df.(*dataFrameImpl).relation.RelType, &proto.Relation_Aggregate{})
+       assert.Equal(t, "max", df.(*dataFrameImpl).relation.GetAggregate().GetAggregateExpressions()[0].GetUnresolvedFunction().FunctionName)
+
+       df, err = gd.Sum(ctx, "col0")
+       assert.NoError(t, err)
+       assert.IsType(t, df.(*dataFrameImpl).relation.RelType, &proto.Relation_Aggregate{})
+       assert.Equal(t, "sum", df.(*dataFrameImpl).relation.GetAggregate().GetAggregateExpressions()[0].GetUnresolvedFunction().FunctionName)
+
+       df, err = gd.Avg(ctx, "col0")
+       assert.NoError(t, err)
+       assert.IsType(t, df.(*dataFrameImpl).relation.RelType, &proto.Relation_Aggregate{})
+       assert.Equal(t, "avg", df.(*dataFrameImpl).relation.GetAggregate().GetAggregateExpressions()[0].GetUnresolvedFunction().FunctionName)
+
+       // Passing no columns should aggregate over all numeric columns
+       df, err = gd.Min(ctx)
+       assert.NoError(t, err)
+       assert.IsType(t, df.(*dataFrameImpl).relation.RelType, &proto.Relation_Aggregate{})
+       assert.Len(t, df.(*dataFrameImpl).relation.GetAggregate().GetAggregateExpressions(), 1)
+}
+
+func TestGroupedData_Count(t *testing.T) {
+       ctx := context.Background()
+       c := client.NewSparkExecutorFromClient(
+               testutils.NewConnectServiceClientMock(nil, mocks.AnalyzePlanResponse, nil, nil), nil, mocks.MockSessionId)
+       session := sparkSessionImpl{sessionId: mocks.MockSessionId, client: c}
+       sampleDataFrame.session = &session
+
+       gd := GroupedData{
+               groupType: "groupby",
+               df:        sampleDataFrame,
+       }
+
+       df, err := gd.Count(ctx)
+       assert.NoError(t, err)
+       assert.IsType(t, df.(*dataFrameImpl).relation.RelType, &proto.Relation_Aggregate{})
+       assert.Equal(t, []string{"count"}, df.(*dataFrameImpl).relation.GetAggregate().GetAggregateExpressions()[0].GetAlias().Name)
+}
diff --git a/spark/sql/types/datatype.go b/spark/sql/types/datatype.go
index 46c3e57..971dc90 100644
--- a/spark/sql/types/datatype.go
+++ b/spark/sql/types/datatype.go
@@ -23,6 +23,7 @@ import (
 
 type DataType interface {
        TypeName() string
+       IsNumeric() bool
 }
 
 type BooleanType struct{}
@@ -31,8 +32,16 @@ func (t BooleanType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t BooleanType) IsNumeric() bool {
+       return false
+}
+
 type ByteType struct{}
 
+func (t ByteType) IsNumeric() bool {
+       return true
+}
+
 func (t ByteType) TypeName() string {
        return getDataTypeName(t)
 }
@@ -43,66 +52,110 @@ func (t ShortType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t ShortType) IsNumeric() bool {
+       return true
+}
+
 type IntegerType struct{}
 
 func (t IntegerType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t IntegerType) IsNumeric() bool {
+       return true
+}
+
 type LongType struct{}
 
 func (t LongType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t LongType) IsNumeric() bool {
+       return true
+}
+
 type FloatType struct{}
 
 func (t FloatType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t FloatType) IsNumeric() bool {
+       return true
+}
+
 type DoubleType struct{}
 
 func (t DoubleType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t DoubleType) IsNumeric() bool {
+       return true
+}
+
 type DecimalType struct{}
 
 func (t DecimalType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t DecimalType) IsNumeric() bool {
+       return true
+}
+
 type StringType struct{}
 
 func (t StringType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t StringType) IsNumeric() bool {
+       return false
+}
+
 type BinaryType struct{}
 
 func (t BinaryType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t BinaryType) IsNumeric() bool {
+       return false
+}
+
 type TimestampType struct{}
 
 func (t TimestampType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t TimestampType) IsNumeric() bool {
+       return false
+}
+
 type TimestampNtzType struct{}
 
 func (t TimestampNtzType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t TimestampNtzType) IsNumeric() bool {
+       return false
+}
+
 type DateType struct{}
 
 func (t DateType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t DateType) IsNumeric() bool {
+       return false
+}
+
 type UnsupportedType struct {
        TypeInfo any
 }
@@ -111,6 +164,10 @@ func (t UnsupportedType) TypeName() string {
        return getDataTypeName(t)
 }
 
+func (t UnsupportedType) IsNumeric() bool {
+       return false
+}
+
 func getDataTypeName(dataType DataType) string {
        typeName := fmt.Sprintf("%T", dataType)
        nonQualifiedTypeName := strings.Split(typeName, ".")[1]
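
The new `IsNumeric` predicate is what drives the default column selection in
`GroupedData.numericAgg` above; a minimal sketch of the check:

```
var dt types.DataType = types.IntegerType{}
if dt.IsNumeric() {
        // eligible as a default target for min/max/avg/sum
}
```
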


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
