This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new ffbe433  Add Support for Struct Conversion when reading Arrow data
ffbe433 is described below

commit ffbe433ce4d1669d7de72ab21617779317a13a25
Author: kronsbein <[email protected]>
AuthorDate: Mon Jan 13 13:00:58 2025 +0100

    Add Support for Struct Conversion when reading Arrow data
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for struct conversion when reading Arrow data: struct columns, including nested structs, are now decoded into Go `map[string]any` values.
    
    ### Why are the changes needed?
    
    Resolves #114
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Struct columns, which previously fell into the "unsupported arrow data type" error path, can now be read when collecting Arrow data.
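    
    For illustration, a minimal sketch of how a struct column can be consumed from the Go client after this change (it mirrors the new integration test; the `sc://localhost` endpoint and the module import path are assumptions):
    
    ```go
    package main
    
    import (
        "context"
        "fmt"
    
        "github.com/apache/spark-connect-go/v35/spark/sql" // import path assumed
    )
    
    func main() {
        ctx := context.Background()
        // Connect to a local Spark Connect server.
        spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
        if err != nil {
            panic(err)
        }
        df, err := spark.Sql(ctx, "select named_struct('a', 1, 'c', cast(10.32 as double)) struct_col")
        if err != nil {
            panic(err)
        }
        rows, err := df.Collect(ctx)
        if err != nil {
            panic(err)
        }
        // Struct values are now decoded into map[string]any keyed by field name.
        structVal := rows[0].Values()[0].(map[string]any)
        fmt.Println(structVal["a"], structVal["c"]) // 1 10.32
    }
    ```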
    
    ### How was this patch tested?
    
    By extending the existing test case [`TestReadArrowRecord`](https://github.com/apache/spark-connect-go/blob/c00cb58be96046e09d41bba0534eeb0417f46e3c/spark/sql/types/arrow_test.go#L74)
    
    Closes #115 from kronsbein/support-struct-conversion-arrow-data.
    
    Authored-by: kronsbein <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/sql_test.go | 42 ++++++++++++++++++++++++++
 spark/sql/types/arrow.go               | 21 +++++++++++++
 spark/sql/types/arrow_test.go          | 54 ++++++++++++++++++++++++++++++++++
 3 files changed, 117 insertions(+)

diff --git a/internal/tests/integration/sql_test.go b/internal/tests/integration/sql_test.go
index 28c32d4..88d7182 100644
--- a/internal/tests/integration/sql_test.go
+++ b/internal/tests/integration/sql_test.go
@@ -69,6 +69,48 @@ func TestIntegration_Schema(t *testing.T) {
        assert.Equal(t, types.LongType{}, schema.Fields[0].DataType)
 }
 
+func TestIntegration_StructConversion(t *testing.T) {
+       ctx := context.Background()
+       spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       query := `
+               select named_struct(
+                       'a', 1,
+                       'b', 2,
+                       'c', cast(10.32 as double),
+                       'd', array(1, 2, 3, 4)
+               ) struct_col
+       `
+       df, err := spark.Sql(ctx, query)
+       assert.NoError(t, err)
+       res, err := df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 1, len(res))
+
+       columnData := res[0].Values()[0]
+       assert.NotNil(t, columnData)
+       structDataMap, ok := columnData.(map[string]any)
+       assert.True(t, ok)
+
+       assert.Contains(t, structDataMap, "a")
+       assert.Contains(t, structDataMap, "b")
+       assert.Contains(t, structDataMap, "c")
+       assert.Contains(t, structDataMap, "d")
+
+       assert.Equal(t, int32(1), structDataMap["a"])
+       assert.Equal(t, int32(2), structDataMap["b"])
+       assert.Equal(t, float64(10.32), structDataMap["c"])
+       arrayData := []any{int32(1), int32(2), int32(3), int32(4)}
+       assert.Equal(t, arrayData, structDataMap["d"])
+
+       schema, err := df.Schema(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, "struct_col", schema.Fields[0].Name)
+}
+
 func TestMain(m *testing.M) {
        envShouldStartService := os.Getenv("START_SPARK_CONNECT_SERVICE")
        shouldStartService := envShouldStartService == "" || envShouldStartService == "1"
diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go
index 2921881..47be9aa 100644
--- a/spark/sql/types/arrow.go
+++ b/spark/sql/types/arrow.go
@@ -260,6 +260,27 @@ func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) {
                        }
                        buf = append(buf, tmp)
                }
+       case arrow.STRUCT:
+               data := array.NewStructData(data)
+               schema := data.DataType().(*arrow.StructType)
+
+               for i := 0; i < data.Len(); i++ {
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                               continue
+                       }
+                       tmp := make(map[string]any)
+
+                       for j := range data.NumField() {
+                               field := data.Field(j)
+                               fieldValues, err := readArrayData(field.DataType().ID(), field.Data())
+                               if err != nil {
+                                       return nil, err
+                               }
+                               tmp[schema.Field(j).Name] = fieldValues[i]
+                       }
+                       buf = append(buf, tmp)
+               }
        default:
                return nil, fmt.Errorf("unsupported arrow data type %s", t.String())
        }
diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go
index 7c2b925..8ec9fbb 100644
--- a/spark/sql/types/arrow_test.go
+++ b/spark/sql/types/arrow_test.go
@@ -137,6 +137,22 @@ func TestReadArrowRecord(t *testing.T) {
                        Name: "map_string_int32",
                        Type: arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32),
                },
+               {
+                       Name: "struct",
+                       Type: arrow.StructOf(
+                               arrow.Field{Name: "field1", Type: arrow.PrimitiveTypes.Int32},
+                               arrow.Field{Name: "field2", Type: arrow.BinaryTypes.String},
+                       ),
+               },
+               {
+                       Name: "nested_struct",
+                       Type: arrow.StructOf(
+                               arrow.Field{Name: "field1", Type: arrow.StructOf(
+                                       arrow.Field{Name: "nested_field1", Type: arrow.PrimitiveTypes.Int32},
+                                       arrow.Field{Name: "nested_field2", Type: arrow.BinaryTypes.String},
+                               )},
+                       ),
+               },
        }
        arrowSchema := arrow.NewSchema(arrowFields, nil)
        var buf bytes.Buffer
@@ -224,6 +240,30 @@ func TestReadArrowRecord(t *testing.T) {
        mb.KeyBuilder().(*array.StringBuilder).Append("key2")
        mb.ItemBuilder().(*array.Int32Builder).Append(2)
 
+       i++
+       sb := recordBuilder.Field(i).(*array.StructBuilder)
+       sb.Append(true)
+       sb.FieldBuilder(0).(*array.Int32Builder).Append(1)
+       sb.FieldBuilder(1).(*array.StringBuilder).Append("str1")
+
+       sb.Append(true)
+       sb.FieldBuilder(0).(*array.Int32Builder).Append(2)
+       sb.FieldBuilder(1).(*array.StringBuilder).Append("str2")
+
+       i++
+       sb = recordBuilder.Field(i).(*array.StructBuilder)
+       sb.Append(true)
+       nsb := sb.FieldBuilder(0).(*array.StructBuilder)
+       nsb.Append(true)
+       nsb.FieldBuilder(0).(*array.Int32Builder).Append(1)
+       nsb.FieldBuilder(1).(*array.StringBuilder).Append("str1_nested")
+
+       sb.Append(true)
+       nsb = sb.FieldBuilder(0).(*array.StructBuilder)
+       nsb.Append(true)
+       nsb.FieldBuilder(0).(*array.Int32Builder).Append(2)
+       nsb.FieldBuilder(1).(*array.StringBuilder).Append("str2_nested")
+
        record := recordBuilder.NewRecord()
        defer record.Release()
 
@@ -239,6 +279,13 @@ func TestReadArrowRecord(t *testing.T) {
                arrow.Timestamp(1686981953115000), arrow.Date64(1686981953117000),
                []any{int64(1), int64(-999231)},
                map[any]any{"key1": int32(1)},
+               map[string]any{"field1": int32(1), "field2": "str1"},
+               map[string]any{
+                       "field1": map[string]any{
+                               "nested_field1": int32(1),
+                               "nested_field2": "str1_nested",
+                       },
+               },
        },
                values[0].Values())
        assert.Equal(t, []any{
@@ -249,6 +296,13 @@ func TestReadArrowRecord(t *testing.T) {
                arrow.Timestamp(1686981953116000), arrow.Date64(1686981953118000),
                []any{int64(1), int64(2), int64(3)},
                map[any]any{"key2": int32(2)},
+               map[string]any{"field1": int32(2), "field2": "str2"},
+               map[string]any{
+                       "field1": map[string]any{
+                               "nested_field1": int32(2),
+                               "nested_field2": "str2_nested",
+                       },
+               },
        },
                values[1].Values())
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
