This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 8432dd0  Cleanup row type and Arrow conversion
8432dd0 is described below

commit 8432dd0ed2849aa4add905962b139470e7806610
Author: Martin Grund <[email protected]>
AuthorDate: Thu Oct 3 04:13:08 2024 -0700

    Cleanup row type and Arrow conversion
    
    ### What changes were proposed in this pull request?
    This patch removes some of the unnecessary indirections of the previous 
implementation of the Row type and adds support for deserializing LIST types 
from Arrow as well.
    
    ### Why are the changes needed?
    Compatibility / Ease of Use
    
    ### Does this PR introduce _any_ user-facing change?
    Now the `Row` type is closer to the PySpark one avoiding the unnecessary 
extraction to `row.values()` to access the values by offset or name.
    
    ### How was this patch tested?
    Added Unit tests
    
    Closes #72 from grundprinzip/cleanup_row_type.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 cmd/spark-connect-example-spark-session/main.go |   2 +-
 internal/tests/integration/dataframe_test.go    | 104 ++++-----
 spark/sql/dataframe.go                          |  69 ++----
 spark/sql/row.go                                |  48 -----
 spark/sql/sparksession_test.go                  |   2 +-
 spark/sql/types/arrow.go                        | 266 ++++++++++++++----------
 spark/sql/types/arrow_test.go                   |  47 ++++-
 spark/sql/{row_test.go => types/row.go}         |  51 +++--
 8 files changed, 282 insertions(+), 307 deletions(-)

diff --git a/cmd/spark-connect-example-spark-session/main.go 
b/cmd/spark-connect-example-spark-session/main.go
index 6a1fe6a..516d87b 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -102,7 +102,7 @@ func main() {
                log.Fatalf("Failed: %s", err)
        }
 
-       schema, err = rows[0].Schema()
+       schema, err = df.Schema(ctx)
        if err != nil {
                log.Fatalf("Failed: %s", err)
        }
diff --git a/internal/tests/integration/dataframe_test.go 
b/internal/tests/integration/dataframe_test.go
index 61ebcdf..bb9e2aa 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -45,9 +45,7 @@ func TestDataFrame_Select(t *testing.T) {
        assert.Equal(t, 100, len(res))
 
        rowZero := res[0]
-       vals, err := rowZero.Values()
-       assert.NoError(t, err)
-       assert.Equal(t, 2, len(vals))
+       assert.Equal(t, 2, rowZero.Len())
 
        df, err = spark.Sql(ctx, "select * from range(100)")
        assert.NoError(t, err)
@@ -69,9 +67,7 @@ func TestDataFrame_SelectExpr(t *testing.T) {
        assert.Equal(t, 100, len(res))
 
        row_zero := res[0]
-       vals, err := row_zero.Values()
-       assert.NoError(t, err)
-       assert.Equal(t, 3, len(vals))
+       assert.Equal(t, 3, row_zero.Len())
 }
 
 func TestDataFrame_Alias(t *testing.T) {
@@ -98,10 +94,7 @@ func TestDataFrame_CrossJoin(t *testing.T) {
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
        assert.Equal(t, 100, len(res))
-
-       v, e := res[0].Values()
-       assert.NoError(t, e)
-       assert.Equal(t, 2, len(v))
+       assert.Equal(t, 2, res[0].Len())
 }
 
 func TestDataFrame_GroupBy(t *testing.T) {
@@ -118,9 +111,8 @@ func TestDataFrame_GroupBy(t *testing.T) {
        res, err = df.Collect(ctx)
        assert.NoError(t, err)
        assert.Equal(t, 1, len(res))
-       vals, _ := res[0].Values()
-       assert.Equal(t, "a", vals[0])
-       assert.Equal(t, int64(10), vals[1])
+       assert.Equal(t, "a", res[0].At(0))
+       assert.Equal(t, int64(10), res[0].At(1))
 }
 
 func TestDataFrame_Count(t *testing.T) {
@@ -172,13 +164,10 @@ func TestSparkSession_CreateDataFrameWithSchema(t 
*testing.T) {
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(res))
-
-       row1, err := res[0].Values()
-       assert.NoError(t, err)
-       assert.Len(t, row1, 3)
-       assert.Equal(t, int32(1), row1[0])
-       assert.Equal(t, 1.1, row1[1])
-       assert.Equal(t, "a", row1[2])
+       assert.Equal(t, 3, res[0].Len())
+       assert.Equal(t, int32(1), res[0].At(0))
+       assert.Equal(t, 1.1, res[0].At(1))
+       assert.Equal(t, "a", res[0].At(2))
 }
 
 func TestDataFrame_Corr(t *testing.T) {
@@ -226,10 +215,8 @@ func TestDataFrame_WithColumn(t *testing.T) {
        assert.Equal(t, 10, len(res))
        // Check the values of the new column
        for _, row := range res {
-               vals, err := row.Values()
-               assert.NoError(t, err)
-               assert.Equal(t, 2, len(vals))
-               assert.Equal(t, int64(1), vals[1])
+               assert.Equal(t, 2, row.Len())
+               assert.Equal(t, int64(1), row.At(1))
        }
 }
 
@@ -247,11 +234,9 @@ func TestDataFrame_WithColumns(t *testing.T) {
        assert.Equal(t, 10, len(res))
        // Check the values of the new columns
        for _, row := range res {
-               vals, err := row.Values()
-               assert.NoError(t, err)
-               assert.Equal(t, 3, len(vals))
-               assert.Equal(t, int64(1), vals[1], "%v", vals)
-               assert.Equal(t, int64(2), vals[2], "%v", vals)
+               assert.Equal(t, 3, row.Len())
+               assert.Equal(t, int64(1), row.At(1))
+               assert.Equal(t, int64(2), row.At(2))
        }
 }
 
@@ -291,10 +276,8 @@ func TestDataFrame_WithColumnRenamed(t *testing.T) {
        assert.Equal(t, 10, len(res))
        // Check the values of the new column
        for i, row := range res {
-               vals, err := row.Values()
-               assert.NoError(t, err)
-               assert.Equal(t, 1, len(vals))
-               assert.Equal(t, int64(i), vals[0])
+               assert.Equal(t, 1, row.Len())
+               assert.Equal(t, int64(i), row.At(0))
        }
 
        // Test that renaming a non-existing column does not change anything.
@@ -397,12 +380,10 @@ func TestDataFrame_DropDuplicates(t *testing.T) {
        rows, err := df.Collect(ctx)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(rows))
-       vals, _ := rows[0].Values()
-       assert.Equal(t, "Alice", vals[0])
-       assert.Equal(t, int32(5), vals[1])
-       vals, _ = rows[1].Values()
-       assert.Equal(t, "Alice", vals[0])
-       assert.Equal(t, int32(10), vals[1])
+       assert.Equal(t, "Alice", rows[0].At(0))
+       assert.Equal(t, int32(5), rows[0].At(1))
+       assert.Equal(t, "Alice", rows[1].At(0))
+       assert.Equal(t, int32(10), rows[1].At(1))
 
        // Test drop duplicates with column names
        df, err = df.DropDuplicates(ctx, "name")
@@ -536,17 +517,13 @@ func TestDataFrame_Sort(t *testing.T) {
        assert.NoError(t, err)
        res, err := df.Head(ctx, 1)
        assert.NoError(t, err)
-       vals, err := res[0].Values()
-       assert.NoError(t, err)
-       assert.Equal(t, int64(9), vals[0])
+       assert.Equal(t, int64(9), res[0].At(0))
 
        df, err = src.Sort(ctx, functions.Col("id").Asc())
        assert.NoError(t, err)
        res, err = df.Head(ctx, 1)
        assert.NoError(t, err)
-       vals, err = res[0].Values()
-       assert.NoError(t, err)
-       assert.Equal(t, int64(0), vals[0])
+       assert.Equal(t, int64(0), res[0].At(0))
 }
 
 func TestDataFrame_Join(t *testing.T) {
@@ -593,10 +570,8 @@ func TestDataFrame_Summary(t *testing.T) {
        res, err := df.Summary(ctx, "count", "stddev").Collect(ctx)
        assert.NoError(t, err)
        assert.Len(t, res, 2)
-       v, err := res[0].Values()
-       assert.NoError(t, err)
-       assert.Equal(t, "count", v[0])
-       assert.Len(t, v, 4)
+       assert.Equal(t, "count", res[0].At(0))
+       assert.Equal(t, 4, res[0].Len())
 }
 
 func TestDataFrame_Pivot(t *testing.T) {
@@ -658,9 +633,7 @@ func TestDataFrame_First(t *testing.T) {
        assert.NoError(t, err)
        row, err := df.First(ctx)
        assert.NoError(t, err)
-       vals, err := row.Values()
-       assert.NoError(t, err)
-       assert.Equal(t, int64(0), vals[0])
+       assert.Equal(t, int64(0), row.At(0))
 }
 
 func TestDataFrame_Distinct(t *testing.T) {
@@ -690,12 +663,10 @@ func TestDataFrame_CrossTab(t *testing.T) {
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
        assert.Len(t, res, 3)
-       v, err := res[0].Values()
-       assert.NoError(t, err)
-       assert.Equal(t, "1", v[0])
-       assert.Equal(t, int64(0), v[1])
-       assert.Equal(t, int64(2), v[2])
-       assert.Equal(t, int64(0), v[3])
+       assert.Equal(t, "1", res[0].At(0))
+       assert.Equal(t, int64(0), res[0].At(1))
+       assert.Equal(t, int64(2), res[0].At(2))
+       assert.Equal(t, int64(0), res[0].At(3))
 }
 
 func TestDataFrame_SameSemantics(t *testing.T) {
@@ -717,12 +688,11 @@ func TestDataFrame_SemanticHash(t *testing.T) {
        assert.NotEmpty(t, hash)
 }
 
-// DISABLED: Because we cannot parse LIST types
-//func TestDataFrame_FreqItems(t *testing.T) {
-//     ctx, spark := connect()
-//     df, err := spark.Sql(ctx, "select * from range(10)")
-//     assert.NoError(t, err)
-//     res, err := df.FreqItems(ctx, "id").Collect(ctx)
-//     assert.NoErrorf(t, err, "%+v", err)
-//     assert.Len(t, res, 1)
-//}
+func TestDataFrame_FreqItems(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select id % 4 as id from range(100)")
+       assert.NoError(t, err)
+       res, err := df.FreqItems(ctx, "id").Collect(ctx)
+       assert.NoErrorf(t, err, "%+v", err)
+       assert.Len(t, res, 1)
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 36da21f..85792b7 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -73,7 +73,7 @@ type DataFrame interface {
        // double value.
        Cov(ctx context.Context, col1, col2 string) (float64, error)
        // Collect returns the data rows of the current data frame.
-       Collect(ctx context.Context) ([]Row, error)
+       Collect(ctx context.Context) ([]types.Row, error)
        // CreateTempView creates or replaces a temporary view.
        CreateTempView(ctx context.Context, viewName string, replace, global 
bool) error
        // CreateOrReplaceTempView creates or replaces a temporary view and 
replaces the optional existing view.
@@ -113,7 +113,7 @@ type DataFrame interface {
        // FilterByString filters the data frame by a string condition.
        FilterByString(ctx context.Context, condition string) (DataFrame, error)
        // Returns the first row of the DataFrame.
-       First(ctx context.Context) (Row, error)
+       First(ctx context.Context) (types.Row, error)
        FreqItems(ctx context.Context, cols ...string) DataFrame
        FreqItemsWithSupport(ctx context.Context, support float64, cols 
...string) DataFrame
        // GetStorageLevel returns the storage level of the data frame.
@@ -122,7 +122,7 @@ type DataFrame interface {
        // can be performed on them. See GroupedData for all the available 
aggregate functions.
        GroupBy(cols ...column.Convertible) *GroupedData
        // Head is an alias for Limit
-       Head(ctx context.Context, limit int32) ([]Row, error)
+       Head(ctx context.Context, limit int32) ([]types.Row, error)
        // Intersect performs the set intersection of two data frames and only 
returns distinct rows.
        Intersect(ctx context.Context, other DataFrame) DataFrame
        // IntersectAll performs the set intersection of two data frames and 
returns all rows.
@@ -170,9 +170,9 @@ type DataFrame interface {
        // this function computes "count", "mean", "stddev", "min", "25%", 
"50%", "75%", "max".
        Summary(ctx context.Context, statistics ...string) DataFrame
        // Tail returns the last `limit` rows as a list of Row.
-       Tail(ctx context.Context, limit int32) ([]Row, error)
+       Tail(ctx context.Context, limit int32) ([]types.Row, error)
        // Take is an alias for Limit
-       Take(ctx context.Context, limit int32) ([]Row, error)
+       Take(ctx context.Context, limit int32) ([]types.Row, error)
        // ToArrow returns the Arrow representation of the DataFrame.
        ToArrow(ctx context.Context) (*arrow.Table, error)
        // Union is an alias for UnionAll
@@ -285,12 +285,12 @@ func (df *dataFrameImpl) CorrWithMethod(ctx 
context.Context, col1, col2 string,
                return 0, err
        }
 
-       values, err := types.ReadArrowTable(table)
+       values, err := types.ReadArrowTableToRows(table)
        if err != nil {
                return 0, err
        }
 
-       return values[0][0].(float64), nil
+       return values[0].At(0).(float64), nil
 }
 
 func (df *dataFrameImpl) Count(ctx context.Context) (int64, error) {
@@ -302,11 +302,8 @@ func (df *dataFrameImpl) Count(ctx context.Context) 
(int64, error) {
        if err != nil {
                return 0, err
        }
-       row, err := rows[0].Values()
-       if err != nil {
-               return 0, err
-       }
-       return row[0].(int64), nil
+
+       return rows[0].At(0).(int64), nil
 }
 
 func (df *dataFrameImpl) Cov(ctx context.Context, col1, col2 string) (float64, 
error) {
@@ -337,12 +334,12 @@ func (df *dataFrameImpl) Cov(ctx context.Context, col1, 
col2 string) (float64, e
                return 0, err
        }
 
-       values, err := types.ReadArrowTable(table)
+       values, err := types.ReadArrowTableToRows(table)
        if err != nil {
                return 0, err
        }
 
-       return values[0][0].(float64), nil
+       return values[0].At(0).(float64), nil
 }
 
 func (df *dataFrameImpl) PlanId() int64 {
@@ -454,29 +451,18 @@ func (df *dataFrameImpl) WriteResult(ctx context.Context, 
collector ResultCollec
                return sparkerrors.WithType(fmt.Errorf("failed to show 
dataframe: %w", err), sparkerrors.ExecutionError)
        }
 
-       schema, table, err := responseClient.ToTable()
+       _, table, err := responseClient.ToTable()
        if err != nil {
                return err
        }
 
-       rows := make([]Row, table.NumRows())
-
-       values, err := types.ReadArrowTable(table)
+       rows, err := types.ReadArrowTableToRows(table)
        if err != nil {
                return err
        }
 
-       for idx, v := range values {
-               row := NewRowWithSchema(v, schema)
-               rows[idx] = row
-       }
-
        for _, row := range rows {
-               values, err := row.Values()
-               if err != nil {
-                       return sparkerrors.WithType(fmt.Errorf(
-                               "failed to get values in the row: %w", err), 
sparkerrors.ReadError)
-               }
+               values := row.Values()
                collector.WriteRow(values)
        }
        return nil
@@ -492,30 +478,17 @@ func (df *dataFrameImpl) Schema(ctx context.Context) 
(*types.StructType, error)
        return types.ConvertProtoDataTypeToStructType(responseSchema)
 }
 
-func (df *dataFrameImpl) Collect(ctx context.Context) ([]Row, error) {
+func (df *dataFrameImpl) Collect(ctx context.Context) ([]types.Row, error) {
        responseClient, err := df.session.client.ExecutePlan(ctx, 
df.createPlan())
        if err != nil {
                return nil, sparkerrors.WithType(fmt.Errorf("failed to execute 
plan: %w", err), sparkerrors.ExecutionError)
        }
 
-       var schema *types.StructType
-       schema, table, err := responseClient.ToTable()
-       if err != nil {
-               return nil, err
-       }
-
-       rows := make([]Row, table.NumRows())
-
-       values, err := types.ReadArrowTable(table)
+       _, table, err := responseClient.ToTable()
        if err != nil {
                return nil, err
        }
-
-       for idx, v := range values {
-               row := NewRowWithSchema(v, schema)
-               rows[idx] = row
-       }
-       return rows, nil
+       return types.ReadArrowTableToRows(table)
 }
 
 func (df *dataFrameImpl) Write() DataFrameWriter {
@@ -852,7 +825,7 @@ func (df *dataFrameImpl) DropDuplicates(ctx 
context.Context, columns ...string)
        return NewDataFrame(df.session, rel), nil
 }
 
-func (df *dataFrameImpl) Tail(ctx context.Context, limit int32) ([]Row, error) 
{
+func (df *dataFrameImpl) Tail(ctx context.Context, limit int32) ([]types.Row, 
error) {
        rel := &proto.Relation{
                Common: &proto.RelationCommon{
                        PlanId: newPlanId(),
@@ -883,11 +856,11 @@ func (df *dataFrameImpl) Limit(ctx context.Context, limit 
int32) DataFrame {
        return NewDataFrame(df.session, rel)
 }
 
-func (df *dataFrameImpl) Head(ctx context.Context, limit int32) ([]Row, error) 
{
+func (df *dataFrameImpl) Head(ctx context.Context, limit int32) ([]types.Row, 
error) {
        return df.Limit(ctx, limit).Collect(ctx)
 }
 
-func (df *dataFrameImpl) Take(ctx context.Context, limit int32) ([]Row, error) 
{
+func (df *dataFrameImpl) Take(ctx context.Context, limit int32) ([]types.Row, 
error) {
        return df.Limit(ctx, limit).Collect(ctx)
 }
 
@@ -1260,7 +1233,7 @@ func (df *dataFrameImpl) Distinct(ctx context.Context) 
DataFrame {
        return NewDataFrame(df.session, rel)
 }
 
-func (df *dataFrameImpl) First(ctx context.Context) (Row, error) {
+func (df *dataFrameImpl) First(ctx context.Context) (types.Row, error) {
        rows, err := df.Head(ctx, 1)
        if err != nil {
                return nil, err
diff --git a/spark/sql/row.go b/spark/sql/row.go
deleted file mode 100644
index 4f821af..0000000
--- a/spark/sql/row.go
+++ /dev/null
@@ -1,48 +0,0 @@
-//
-// Licensed to the Apache Software Foundation (ASF) under one or more
-// contributor license agreements.  See the NOTICE file distributed with
-// this work for additional information regarding copyright ownership.
-// The ASF licenses this file to You under the Apache License, Version 2.0
-// (the "License"); you may not use this file except in compliance with
-// the License.  You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sql
-
-import "github.com/apache/spark-connect-go/v35/spark/sql/types"
-
-// Row represents a row in a DataFrame.
-type Row interface {
-       // Schema returns the schema of the row.
-       Schema() (*types.StructType, error)
-       // Values returns the values of the row.
-       Values() ([]any, error)
-}
-
-// genericRowWithSchema represents a row in a DataFrame with schema.
-type genericRowWithSchema struct {
-       values []any
-       schema *types.StructType
-}
-
-func NewRowWithSchema(values []any, schema *types.StructType) Row {
-       return &genericRowWithSchema{
-               values: values,
-               schema: schema,
-       }
-}
-
-func (r *genericRowWithSchema) Schema() (*types.StructType, error) {
-       return r.schema, nil
-}
-
-func (r *genericRowWithSchema) Values() ([]any, error) {
-       return r.values, nil
-}
diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go
index a78681e..80f58f3 100644
--- a/spark/sql/sparksession_test.go
+++ b/spark/sql/sparksession_test.go
@@ -187,7 +187,7 @@ func TestWriteResultStreamsArrowResultToCollector(t 
*testing.T) {
        assert.NoError(t, err)
        rows, err := df.Collect(ctx)
        assert.NoError(t, err)
-       vals, err := rows[1].Values()
+       vals := rows[1].Values()
        assert.NoError(t, err)
        assert.Equal(t, []any{"str2"}, vals)
 }
diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go
index 02bd7c5..0131c82 100644
--- a/spark/sql/types/arrow.go
+++ b/spark/sql/types/arrow.go
@@ -29,159 +29,195 @@ import (
        "github.com/apache/spark-connect-go/v35/spark/sparkerrors"
 )
 
-func ReadArrowTable(table arrow.Table) ([][]any, error) {
-       numRows := table.NumRows()
-       numColumns := int(table.NumCols())
+func ReadArrowTableToRows(table arrow.Table) ([]Row, error) {
+       result := make([]Row, table.NumRows())
 
-       values := make([][]any, numRows)
-       for i := range values {
-               values[i] = make([]any, numColumns)
-       }
-
-       for columnIndex := 0; columnIndex < numColumns; columnIndex++ {
-               err := ReadArrowRecordColumn(table, columnIndex, values)
+       // For each column in the table, read the data and convert it to an 
array of any.
+       cols := make([][]any, table.NumCols())
+       for i := 0; i < int(table.NumCols()); i++ {
+               chunkedColumn := table.Column(i).Data()
+               column, err := readChunkedColumn(chunkedColumn)
                if err != nil {
                        return nil, err
                }
+               cols[i] = column
+       }
+
+       // Create a list of field names for the rows.
+       fieldNames := make([]string, table.NumCols())
+       for i, field := range table.Schema().Fields() {
+               fieldNames[i] = field.Name
        }
-       return values, nil
+
+       // Create the rows:
+       for i := 0; i < int(table.NumRows()); i++ {
+               row := make([]any, table.NumCols())
+               for j := 0; j < int(table.NumCols()); j++ {
+                       row[j] = cols[j][i]
+               }
+               r := &rowImpl{
+                       values:  row,
+                       offsets: make(map[string]int),
+               }
+               for j, fieldName := range fieldNames {
+                       r.offsets[fieldName] = j
+               }
+               result[i] = r
+       }
+
+       return result, nil
 }
 
-// readArrowRecordColumn reads all values in a column and stores them in values
-func ReadArrowRecordColumn(record arrow.Table, columnIndex int, values 
[][]any) error {
-       chunkedColumn := record.Column(columnIndex).Data()
-       dataTypeId := chunkedColumn.DataType().ID()
-       switch dataTypeId {
+func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) {
+       buf := make([]any, 0)
+       // Switch over the type t and append the values to buf.
+       switch t {
        case arrow.BOOL:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewBooleanData(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewBooleanData(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.INT8:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewInt8Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewInt8Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.INT16:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewInt16Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewInt16Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.INT32:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewInt32Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewInt32Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.INT64:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewInt64Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewInt64Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.FLOAT16:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewFloat16Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewFloat16Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.FLOAT32:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewFloat32Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewFloat32Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.FLOAT64:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewFloat64Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewFloat64Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.DECIMAL | arrow.DECIMAL128:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewDecimal128Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewDecimal128Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.DECIMAL256:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewDecimal256Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewDecimal256Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.STRING:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewStringData(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewStringData(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.BINARY:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewBinaryData(columnData.Data())
-                       for i := 0; rowIndex < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewBinaryData(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.TIMESTAMP:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewTimestampData(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
-                       }
+               data := array.NewTimestampData(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
                }
        case arrow.DATE64:
-               rowIndex := 0
-               for _, columnData := range chunkedColumn.Chunks() {
-                       vector := array.NewDate64Data(columnData.Data())
-                       for i := 0; i < columnData.Len(); i++ {
-                               values[rowIndex][columnIndex] = vector.Value(i)
-                               rowIndex += 1
+               data := array.NewDate64Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
+               }
+       case arrow.DATE32:
+               data := array.NewDate32Data(data)
+               for i := 0; i < data.Len(); i++ {
+                       buf = append(buf, data.Value(i))
+               }
+       case arrow.LIST:
+               data := array.NewListData(data)
+               values := data.ListValues()
+
+               res, err := readArrayData(values.DataType().ID(), values.Data())
+               if err != nil {
+                       return nil, err
+               }
+
+               for i := 0; i < data.Len(); i++ {
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                               continue
                        }
+                       start := data.Offsets()[i]
+                       end := data.Offsets()[i+1]
+                       // TODO: Unfortunately, this ends up being stored as a 
slice of slices of any. But not
+                       // the right type.
+                       buf = append(buf, res[start:end])
+               }
+       case arrow.MAP:
+               // For maps the data is stored as a list of key value pairs. So 
to extract the maps,
+               // we follow the same behavior as for lists but with two sub 
lists.
+               data := array.NewMapData(data)
+               keys := data.Keys()
+               values := data.Items()
+
+               keyValues, err := readArrayData(keys.DataType().ID(), 
keys.Data())
+               if err != nil {
+                       return nil, err
+               }
+               valueValues, err := readArrayData(values.DataType().ID(), 
values.Data())
+               if err != nil {
+                       return nil, err
+               }
+
+               for i := 0; i < data.Len(); i++ {
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                               continue
+                       }
+                       tmp := make(map[any]any)
+
+                       start := data.Offsets()[i]
+                       end := data.Offsets()[i+1]
+
+                       k := keyValues[start:end]
+                       v := valueValues[start:end]
+                       for j := 0; j < len(k); j++ {
+                               tmp[k[j]] = v[j]
+                       }
+                       buf = append(buf, tmp)
                }
        default:
-               return fmt.Errorf("unsupported arrow data type %s in column 
%d", dataTypeId.String(), columnIndex)
+               return nil, fmt.Errorf("unsupported arrow data type %s", 
t.String())
+       }
+       return buf, nil
+}
+
+func readChunkedColumn(chunked *arrow.Chunked) ([]any, error) {
+       buf := make([]any, 0)
+       for _, chunk := range chunked.Chunks() {
+               data := chunk.Data()
+               t := data.DataType().ID()
+               values, err := readArrayData(t, data)
+               if err != nil {
+                       return nil, err
+               }
+               buf = append(buf, values...)
        }
-       return nil
+       return buf, nil
 }
 
 func ReadArrowBatchToRecord(data []byte, schema *StructType) (arrow.Record, 
error) {
diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go
index 30fc14b..8e99c78 100644
--- a/spark/sql/types/arrow_test.go
+++ b/spark/sql/types/arrow_test.go
@@ -64,11 +64,11 @@ func TestShowArrowBatchData(t *testing.T) {
        require.NoError(t, err)
 
        table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record})
-       values, err := types.ReadArrowTable(table)
+       values, err := types.ReadArrowTableToRows(table)
        require.Nil(t, err)
        assert.Equal(t, 2, len(values))
-       assert.Equal(t, []any{"str1a\nstr1b"}, values[0])
-       assert.Equal(t, []any{"str2"}, values[1])
+       assert.Equal(t, []any{"str1a\nstr1b"}, values[0].Values())
+       assert.Equal(t, []any{"str2"}, values[1].Values())
 }
 
 func TestReadArrowRecord(t *testing.T) {
@@ -129,6 +129,14 @@ func TestReadArrowRecord(t *testing.T) {
                        Name: "date64_column",
                        Type: &arrow.Date64Type{},
                },
+               {
+                       Name: "array_int64_column",
+                       Type: arrow.ListOf(arrow.PrimitiveTypes.Int64),
+               },
+               {
+                       Name: "map_string_int32",
+                       Type: arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32),
+               },
        }
        arrowSchema := arrow.NewSchema(arrowFields, nil)
        var buf bytes.Buffer
@@ -195,11 +203,32 @@ func TestReadArrowRecord(t *testing.T) {
        recordBuilder.Field(i).(*array.Date64Builder).Append(arrow.Date64(1686981953117000))
        recordBuilder.Field(i).(*array.Date64Builder).Append(arrow.Date64(1686981953118000))
 
+       i++
+       lb := recordBuilder.Field(i).(*array.ListBuilder)
+       lb.Append(true)
+       lb.ValueBuilder().(*array.Int64Builder).Append(1)
+       lb.ValueBuilder().(*array.Int64Builder).Append(-999231)
+
+       lb.Append(true)
+       lb.ValueBuilder().(*array.Int64Builder).Append(1)
+       lb.ValueBuilder().(*array.Int64Builder).Append(2)
+       lb.ValueBuilder().(*array.Int64Builder).Append(3)
+
+       i++
+       mb := recordBuilder.Field(i).(*array.MapBuilder)
+       mb.Append(true)
+       mb.KeyBuilder().(*array.StringBuilder).Append("key1")
+       mb.ItemBuilder().(*array.Int32Builder).Append(1)
+
+       mb.Append(true)
+       mb.KeyBuilder().(*array.StringBuilder).Append("key2")
+       mb.ItemBuilder().(*array.Int32Builder).Append(2)
+
        record := recordBuilder.NewRecord()
        defer record.Release()
 
        table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record})
-       values, err := types.ReadArrowTable(table)
+       values, err := types.ReadArrowTableToRows(table)
        require.Nil(t, err)
        assert.Equal(t, 2, len(values))
        assert.Equal(t, []any{
@@ -208,16 +237,20 @@ func TestReadArrowRecord(t *testing.T) {
                decimal128.FromI64(10000000), decimal256.FromI64(100000000),
                "str1", []byte("bytes1"),
                arrow.Timestamp(1686981953115000), arrow.Date64(1686981953117000),
+               []any{int64(1), int64(-999231)},
+               map[any]any{"key1": int32(1)},
        },
-               values[0])
+               values[0].Values())
        assert.Equal(t, []any{
                true, int8(2), int16(20), int32(200), int64(2000),
                float16.New(20000.1), float32(200000.1), 2000000.1,
                decimal128.FromI64(20000000), decimal256.FromI64(200000000),
                "str2", []byte("bytes2"),
                arrow.Timestamp(1686981953116000), arrow.Date64(1686981953118000),
+               []any{int64(1), int64(2), int64(3)},
+               map[any]any{"key2": int32(2)},
        },
-               values[1])
+               values[1].Values())
 }
 
 func TestReadArrowRecord_UnsupportedType(t *testing.T) {
@@ -242,7 +275,7 @@ func TestReadArrowRecord_UnsupportedType(t *testing.T) {
        defer record.Release()
 
        table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record})
-       _, err := types.ReadArrowTable(table)
+       _, err := types.ReadArrowTableToRows(table)
        require.NotNil(t, err)
 }
 
diff --git a/spark/sql/row_test.go b/spark/sql/types/row.go
similarity index 57%
rename from spark/sql/row_test.go
rename to spark/sql/types/row.go
index c4a384b..b73ef57 100644
--- a/spark/sql/row_test.go
+++ b/spark/sql/types/row.go
@@ -14,30 +14,41 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package sql
+package types
 
-import (
-       "testing"
// Row represents a single row of a query result and provides access to its
// column values by position or by field name, mirroring the PySpark Row type.
type Row interface {
	// At returns the value at the given column offset.
	At(index int) any
	// Value returns the value for the named column.
	Value(name string) any
	// Values returns all column values in column order.
	Values() []any
	// Len returns the number of columns in the row.
	Len() int
	// FieldNames returns the column names of the row (in no defined order,
	// since Go map iteration order is randomized).
	FieldNames() []string
}

// rowImpl is the default Row implementation, backed by a value slice and a
// field-name -> offset lookup table.
type rowImpl struct {
	values  []any
	offsets map[string]int
}

// At returns the value at the given column offset. Panics if index is out of
// range, matching slice-indexing semantics.
func (r *rowImpl) At(index int) any {
	return r.values[index]
}

// Value returns the value for the named column.
// NOTE(review): an unknown name yields the map's zero value (offset 0) and so
// silently returns the first column — confirm whether callers rely on this.
func (r *rowImpl) Value(name string) any {
	return r.values[r.offsets[name]]
}

// Values returns the underlying value slice (not a copy).
func (r *rowImpl) Values() []any {
	return r.values
}

// Len returns the number of columns in the row.
func (r *rowImpl) Len() int {
	return len(r.values)
}

// FieldNames returns the column names of the row.
func (r *rowImpl) FieldNames() []string {
	// Allocate with zero length and full capacity: the previous
	// make([]string, len(r.offsets)) + append produced a slice of length
	// 2n whose first n entries were empty strings.
	names := make([]string, 0, len(r.offsets))
	for name := range r.offsets {
		names = append(names, name)
	}
	return names
}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to