This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new ffbe433 Add Support for Struct Conversion when reading Arrow data
ffbe433 is described below
commit ffbe433ce4d1669d7de72ab21617779317a13a25
Author: kronsbein <[email protected]>
AuthorDate: Mon Jan 13 13:00:58 2025 +0100
Add Support for Struct Conversion when reading Arrow data
### What changes were proposed in this pull request?
This PR adds support for struct conversion when reading Arrow data.
### Why are the changes needed?
Resolves #114
### Does this PR introduce _any_ user-facing change?
Additional functionality when reading Arrow data.
### How was this patch tested?
By extending existing test case
[`TestReadArrowRecord`](https://github.com/apache/spark-connect-go/blob/c00cb58be96046e09d41bba0534eeb0417f46e3c/spark/sql/types/arrow_test.go#L74)
Closes #115 from kronsbein/support-struct-conversion-arrow-data.
Authored-by: kronsbein <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/sql_test.go | 42 ++++++++++++++++++++++++++
spark/sql/types/arrow.go | 21 +++++++++++++
spark/sql/types/arrow_test.go | 54 ++++++++++++++++++++++++++++++++++
3 files changed, 117 insertions(+)
diff --git a/internal/tests/integration/sql_test.go
b/internal/tests/integration/sql_test.go
index 28c32d4..88d7182 100644
--- a/internal/tests/integration/sql_test.go
+++ b/internal/tests/integration/sql_test.go
@@ -69,6 +69,48 @@ func TestIntegration_Schema(t *testing.T) {
assert.Equal(t, types.LongType{}, schema.Fields[0].DataType)
}
+func TestIntegration_StructConversion(t *testing.T) {
+ ctx := context.Background()
+ spark, err :=
sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ query := `
+ select named_struct(
+ 'a', 1,
+ 'b', 2,
+ 'c', cast(10.32 as double),
+ 'd', array(1, 2, 3, 4)
+ ) struct_col
+ `
+ df, err := spark.Sql(ctx, query)
+ assert.NoError(t, err)
+ res, err := df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, 1, len(res))
+
+ columnData := res[0].Values()[0]
+ assert.NotNil(t, columnData)
+ structDataMap, ok := columnData.(map[string]any)
+ assert.True(t, ok)
+
+ assert.Contains(t, structDataMap, "a")
+ assert.Contains(t, structDataMap, "b")
+ assert.Contains(t, structDataMap, "c")
+ assert.Contains(t, structDataMap, "d")
+
+ assert.Equal(t, int32(1), structDataMap["a"])
+ assert.Equal(t, int32(2), structDataMap["b"])
+ assert.Equal(t, float64(10.32), structDataMap["c"])
+ arrayData := []any{int32(1), int32(2), int32(3), int32(4)}
+ assert.Equal(t, arrayData, structDataMap["d"])
+
+ schema, err := df.Schema(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, "struct_col", schema.Fields[0].Name)
+}
+
func TestMain(m *testing.M) {
envShouldStartService := os.Getenv("START_SPARK_CONNECT_SERVICE")
shouldStartService := envShouldStartService == "" ||
envShouldStartService == "1"
diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go
index 2921881..47be9aa 100644
--- a/spark/sql/types/arrow.go
+++ b/spark/sql/types/arrow.go
@@ -260,6 +260,27 @@ func readArrayData(t arrow.Type, data arrow.ArrayData)
([]any, error) {
}
buf = append(buf, tmp)
}
+ case arrow.STRUCT:
+ data := array.NewStructData(data)
+ schema := data.DataType().(*arrow.StructType)
+
+ for i := 0; i < data.Len(); i++ {
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ continue
+ }
+ tmp := make(map[string]any)
+
+ for j := range data.NumField() {
+ field := data.Field(j)
+ fieldValues, err :=
readArrayData(field.DataType().ID(), field.Data())
+ if err != nil {
+ return nil, err
+ }
+ tmp[schema.Field(j).Name] = fieldValues[i]
+ }
+ buf = append(buf, tmp)
+ }
default:
return nil, fmt.Errorf("unsupported arrow data type %s",
t.String())
}
diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go
index 7c2b925..8ec9fbb 100644
--- a/spark/sql/types/arrow_test.go
+++ b/spark/sql/types/arrow_test.go
@@ -137,6 +137,22 @@ func TestReadArrowRecord(t *testing.T) {
Name: "map_string_int32",
Type: arrow.MapOf(arrow.BinaryTypes.String,
arrow.PrimitiveTypes.Int32),
},
+ {
+ Name: "struct",
+ Type: arrow.StructOf(
+ arrow.Field{Name: "field1", Type:
arrow.PrimitiveTypes.Int32},
+ arrow.Field{Name: "field2", Type:
arrow.BinaryTypes.String},
+ ),
+ },
+ {
+ Name: "nested_struct",
+ Type: arrow.StructOf(
+ arrow.Field{Name: "field1", Type:
arrow.StructOf(
+ arrow.Field{Name: "nested_field1",
Type: arrow.PrimitiveTypes.Int32},
+ arrow.Field{Name: "nested_field2",
Type: arrow.BinaryTypes.String},
+ )},
+ ),
+ },
}
arrowSchema := arrow.NewSchema(arrowFields, nil)
var buf bytes.Buffer
@@ -224,6 +240,30 @@ func TestReadArrowRecord(t *testing.T) {
mb.KeyBuilder().(*array.StringBuilder).Append("key2")
mb.ItemBuilder().(*array.Int32Builder).Append(2)
+ i++
+ sb := recordBuilder.Field(i).(*array.StructBuilder)
+ sb.Append(true)
+ sb.FieldBuilder(0).(*array.Int32Builder).Append(1)
+ sb.FieldBuilder(1).(*array.StringBuilder).Append("str1")
+
+ sb.Append(true)
+ sb.FieldBuilder(0).(*array.Int32Builder).Append(2)
+ sb.FieldBuilder(1).(*array.StringBuilder).Append("str2")
+
+ i++
+ sb = recordBuilder.Field(i).(*array.StructBuilder)
+ sb.Append(true)
+ nsb := sb.FieldBuilder(0).(*array.StructBuilder)
+ nsb.Append(true)
+ nsb.FieldBuilder(0).(*array.Int32Builder).Append(1)
+ nsb.FieldBuilder(1).(*array.StringBuilder).Append("str1_nested")
+
+ sb.Append(true)
+ nsb = sb.FieldBuilder(0).(*array.StructBuilder)
+ nsb.Append(true)
+ nsb.FieldBuilder(0).(*array.Int32Builder).Append(2)
+ nsb.FieldBuilder(1).(*array.StringBuilder).Append("str2_nested")
+
record := recordBuilder.NewRecord()
defer record.Release()
@@ -239,6 +279,13 @@ func TestReadArrowRecord(t *testing.T) {
arrow.Timestamp(1686981953115000),
arrow.Date64(1686981953117000),
[]any{int64(1), int64(-999231)},
map[any]any{"key1": int32(1)},
+ map[string]any{"field1": int32(1), "field2": "str1"},
+ map[string]any{
+ "field1": map[string]any{
+ "nested_field1": int32(1),
+ "nested_field2": "str1_nested",
+ },
+ },
},
values[0].Values())
assert.Equal(t, []any{
@@ -249,6 +296,13 @@ func TestReadArrowRecord(t *testing.T) {
arrow.Timestamp(1686981953116000),
arrow.Date64(1686981953118000),
[]any{int64(1), int64(2), int64(3)},
map[any]any{"key2": int32(2)},
+ map[string]any{"field1": int32(2), "field2": "str2"},
+ map[string]any{
+ "field1": map[string]any{
+ "nested_field1": int32(2),
+ "nested_field2": "str2_nested",
+ },
+ },
},
values[1].Values())
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]