This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new 8432dd0 Cleanup row type and Arrow conversion
8432dd0 is described below
commit 8432dd0ed2849aa4add905962b139470e7806610
Author: Martin Grund <[email protected]>
AuthorDate: Thu Oct 3 04:13:08 2024 -0700
Cleanup row type and Arrow conversion
### What changes were proposed in this pull request?
This patch removes some of the unnecessary indirections of the previous
implementation of the Row type and adds support for deserializing LIST types
from Arrow as well.
### Why are the changes needed?
Compatibility / Ease of Use
### Does this PR introduce _any_ user-facing change?
Now the `Row` type is closer to the PySpark one avoiding the unnecessary
extraction to `row.values()` to access the values by offset or name.
### How was this patch tested?
Added Unit tests
Closes #72 from grundprinzip/cleanup_row_type.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
cmd/spark-connect-example-spark-session/main.go | 2 +-
internal/tests/integration/dataframe_test.go | 104 ++++-----
spark/sql/dataframe.go | 69 ++----
spark/sql/row.go | 48 -----
spark/sql/sparksession_test.go | 2 +-
spark/sql/types/arrow.go | 266 ++++++++++++++----------
spark/sql/types/arrow_test.go | 47 ++++-
spark/sql/{row_test.go => types/row.go} | 51 +++--
8 files changed, 282 insertions(+), 307 deletions(-)
diff --git a/cmd/spark-connect-example-spark-session/main.go
b/cmd/spark-connect-example-spark-session/main.go
index 6a1fe6a..516d87b 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -102,7 +102,7 @@ func main() {
log.Fatalf("Failed: %s", err)
}
- schema, err = rows[0].Schema()
+ schema, err = df.Schema(ctx)
if err != nil {
log.Fatalf("Failed: %s", err)
}
diff --git a/internal/tests/integration/dataframe_test.go
b/internal/tests/integration/dataframe_test.go
index 61ebcdf..bb9e2aa 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -45,9 +45,7 @@ func TestDataFrame_Select(t *testing.T) {
assert.Equal(t, 100, len(res))
rowZero := res[0]
- vals, err := rowZero.Values()
- assert.NoError(t, err)
- assert.Equal(t, 2, len(vals))
+ assert.Equal(t, 2, rowZero.Len())
df, err = spark.Sql(ctx, "select * from range(100)")
assert.NoError(t, err)
@@ -69,9 +67,7 @@ func TestDataFrame_SelectExpr(t *testing.T) {
assert.Equal(t, 100, len(res))
row_zero := res[0]
- vals, err := row_zero.Values()
- assert.NoError(t, err)
- assert.Equal(t, 3, len(vals))
+ assert.Equal(t, 3, row_zero.Len())
}
func TestDataFrame_Alias(t *testing.T) {
@@ -98,10 +94,7 @@ func TestDataFrame_CrossJoin(t *testing.T) {
res, err := df.Collect(ctx)
assert.NoError(t, err)
assert.Equal(t, 100, len(res))
-
- v, e := res[0].Values()
- assert.NoError(t, e)
- assert.Equal(t, 2, len(v))
+ assert.Equal(t, 2, res[0].Len())
}
func TestDataFrame_GroupBy(t *testing.T) {
@@ -118,9 +111,8 @@ func TestDataFrame_GroupBy(t *testing.T) {
res, err = df.Collect(ctx)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
- vals, _ := res[0].Values()
- assert.Equal(t, "a", vals[0])
- assert.Equal(t, int64(10), vals[1])
+ assert.Equal(t, "a", res[0].At(0))
+ assert.Equal(t, int64(10), res[0].At(1))
}
func TestDataFrame_Count(t *testing.T) {
@@ -172,13 +164,10 @@ func TestSparkSession_CreateDataFrameWithSchema(t
*testing.T) {
res, err := df.Collect(ctx)
assert.NoError(t, err)
assert.Equal(t, 2, len(res))
-
- row1, err := res[0].Values()
- assert.NoError(t, err)
- assert.Len(t, row1, 3)
- assert.Equal(t, int32(1), row1[0])
- assert.Equal(t, 1.1, row1[1])
- assert.Equal(t, "a", row1[2])
+ assert.Equal(t, 3, res[0].Len())
+ assert.Equal(t, int32(1), res[0].At(0))
+ assert.Equal(t, 1.1, res[0].At(1))
+ assert.Equal(t, "a", res[0].At(2))
}
func TestDataFrame_Corr(t *testing.T) {
@@ -226,10 +215,8 @@ func TestDataFrame_WithColumn(t *testing.T) {
assert.Equal(t, 10, len(res))
// Check the values of the new column
for _, row := range res {
- vals, err := row.Values()
- assert.NoError(t, err)
- assert.Equal(t, 2, len(vals))
- assert.Equal(t, int64(1), vals[1])
+ assert.Equal(t, 2, row.Len())
+ assert.Equal(t, int64(1), row.At(1))
}
}
@@ -247,11 +234,9 @@ func TestDataFrame_WithColumns(t *testing.T) {
assert.Equal(t, 10, len(res))
// Check the values of the new columns
for _, row := range res {
- vals, err := row.Values()
- assert.NoError(t, err)
- assert.Equal(t, 3, len(vals))
- assert.Equal(t, int64(1), vals[1], "%v", vals)
- assert.Equal(t, int64(2), vals[2], "%v", vals)
+ assert.Equal(t, 3, row.Len())
+ assert.Equal(t, int64(1), row.At(1))
+ assert.Equal(t, int64(2), row.At(2))
}
}
@@ -291,10 +276,8 @@ func TestDataFrame_WithColumnRenamed(t *testing.T) {
assert.Equal(t, 10, len(res))
// Check the values of the new column
for i, row := range res {
- vals, err := row.Values()
- assert.NoError(t, err)
- assert.Equal(t, 1, len(vals))
- assert.Equal(t, int64(i), vals[0])
+ assert.Equal(t, 1, row.Len())
+ assert.Equal(t, int64(i), row.At(0))
}
// Test that renaming a non-existing column does not change anything.
@@ -397,12 +380,10 @@ func TestDataFrame_DropDuplicates(t *testing.T) {
rows, err := df.Collect(ctx)
assert.NoError(t, err)
assert.Equal(t, 2, len(rows))
- vals, _ := rows[0].Values()
- assert.Equal(t, "Alice", vals[0])
- assert.Equal(t, int32(5), vals[1])
- vals, _ = rows[1].Values()
- assert.Equal(t, "Alice", vals[0])
- assert.Equal(t, int32(10), vals[1])
+ assert.Equal(t, "Alice", rows[0].At(0))
+ assert.Equal(t, int32(5), rows[0].At(1))
+ assert.Equal(t, "Alice", rows[1].At(0))
+ assert.Equal(t, int32(10), rows[1].At(1))
// Test drop duplicates with column names
df, err = df.DropDuplicates(ctx, "name")
@@ -536,17 +517,13 @@ func TestDataFrame_Sort(t *testing.T) {
assert.NoError(t, err)
res, err := df.Head(ctx, 1)
assert.NoError(t, err)
- vals, err := res[0].Values()
- assert.NoError(t, err)
- assert.Equal(t, int64(9), vals[0])
+ assert.Equal(t, int64(9), res[0].At(0))
df, err = src.Sort(ctx, functions.Col("id").Asc())
assert.NoError(t, err)
res, err = df.Head(ctx, 1)
assert.NoError(t, err)
- vals, err = res[0].Values()
- assert.NoError(t, err)
- assert.Equal(t, int64(0), vals[0])
+ assert.Equal(t, int64(0), res[0].At(0))
}
func TestDataFrame_Join(t *testing.T) {
@@ -593,10 +570,8 @@ func TestDataFrame_Summary(t *testing.T) {
res, err := df.Summary(ctx, "count", "stddev").Collect(ctx)
assert.NoError(t, err)
assert.Len(t, res, 2)
- v, err := res[0].Values()
- assert.NoError(t, err)
- assert.Equal(t, "count", v[0])
- assert.Len(t, v, 4)
+ assert.Equal(t, "count", res[0].At(0))
+ assert.Equal(t, 4, res[0].Len())
}
func TestDataFrame_Pivot(t *testing.T) {
@@ -658,9 +633,7 @@ func TestDataFrame_First(t *testing.T) {
assert.NoError(t, err)
row, err := df.First(ctx)
assert.NoError(t, err)
- vals, err := row.Values()
- assert.NoError(t, err)
- assert.Equal(t, int64(0), vals[0])
+ assert.Equal(t, int64(0), row.At(0))
}
func TestDataFrame_Distinct(t *testing.T) {
@@ -690,12 +663,10 @@ func TestDataFrame_CrossTab(t *testing.T) {
res, err := df.Collect(ctx)
assert.NoError(t, err)
assert.Len(t, res, 3)
- v, err := res[0].Values()
- assert.NoError(t, err)
- assert.Equal(t, "1", v[0])
- assert.Equal(t, int64(0), v[1])
- assert.Equal(t, int64(2), v[2])
- assert.Equal(t, int64(0), v[3])
+ assert.Equal(t, "1", res[0].At(0))
+ assert.Equal(t, int64(0), res[0].At(1))
+ assert.Equal(t, int64(2), res[0].At(2))
+ assert.Equal(t, int64(0), res[0].At(3))
}
func TestDataFrame_SameSemantics(t *testing.T) {
@@ -717,12 +688,11 @@ func TestDataFrame_SemanticHash(t *testing.T) {
assert.NotEmpty(t, hash)
}
-// DISABLED: Because we cannot parse LIST types
-//func TestDataFrame_FreqItems(t *testing.T) {
-// ctx, spark := connect()
-// df, err := spark.Sql(ctx, "select * from range(10)")
-// assert.NoError(t, err)
-// res, err := df.FreqItems(ctx, "id").Collect(ctx)
-// assert.NoErrorf(t, err, "%+v", err)
-// assert.Len(t, res, 1)
-//}
+func TestDataFrame_FreqItems(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select id % 4 as id from range(100)")
+ assert.NoError(t, err)
+ res, err := df.FreqItems(ctx, "id").Collect(ctx)
+ assert.NoErrorf(t, err, "%+v", err)
+ assert.Len(t, res, 1)
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 36da21f..85792b7 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -73,7 +73,7 @@ type DataFrame interface {
// double value.
Cov(ctx context.Context, col1, col2 string) (float64, error)
// Collect returns the data rows of the current data frame.
- Collect(ctx context.Context) ([]Row, error)
+ Collect(ctx context.Context) ([]types.Row, error)
// CreateTempView creates or replaces a temporary view.
CreateTempView(ctx context.Context, viewName string, replace, global
bool) error
// CreateOrReplaceTempView creates or replaces a temporary view and
replaces the optional existing view.
@@ -113,7 +113,7 @@ type DataFrame interface {
// FilterByString filters the data frame by a string condition.
FilterByString(ctx context.Context, condition string) (DataFrame, error)
// Returns the first row of the DataFrame.
- First(ctx context.Context) (Row, error)
+ First(ctx context.Context) (types.Row, error)
FreqItems(ctx context.Context, cols ...string) DataFrame
FreqItemsWithSupport(ctx context.Context, support float64, cols
...string) DataFrame
// GetStorageLevel returns the storage level of the data frame.
@@ -122,7 +122,7 @@ type DataFrame interface {
// can be performed on them. See GroupedData for all the available
aggregate functions.
GroupBy(cols ...column.Convertible) *GroupedData
// Head is an alias for Limit
- Head(ctx context.Context, limit int32) ([]Row, error)
+ Head(ctx context.Context, limit int32) ([]types.Row, error)
// Intersect performs the set intersection of two data frames and only
returns distinct rows.
Intersect(ctx context.Context, other DataFrame) DataFrame
// IntersectAll performs the set intersection of two data frames and
returns all rows.
@@ -170,9 +170,9 @@ type DataFrame interface {
// this function computes "count", "mean", "stddev", "min", "25%",
"50%", "75%", "max".
Summary(ctx context.Context, statistics ...string) DataFrame
// Tail returns the last `limit` rows as a list of Row.
- Tail(ctx context.Context, limit int32) ([]Row, error)
+ Tail(ctx context.Context, limit int32) ([]types.Row, error)
// Take is an alias for Limit
- Take(ctx context.Context, limit int32) ([]Row, error)
+ Take(ctx context.Context, limit int32) ([]types.Row, error)
// ToArrow returns the Arrow representation of the DataFrame.
ToArrow(ctx context.Context) (*arrow.Table, error)
// Union is an alias for UnionAll
@@ -285,12 +285,12 @@ func (df *dataFrameImpl) CorrWithMethod(ctx
context.Context, col1, col2 string,
return 0, err
}
- values, err := types.ReadArrowTable(table)
+ values, err := types.ReadArrowTableToRows(table)
if err != nil {
return 0, err
}
- return values[0][0].(float64), nil
+ return values[0].At(0).(float64), nil
}
func (df *dataFrameImpl) Count(ctx context.Context) (int64, error) {
@@ -302,11 +302,8 @@ func (df *dataFrameImpl) Count(ctx context.Context)
(int64, error) {
if err != nil {
return 0, err
}
- row, err := rows[0].Values()
- if err != nil {
- return 0, err
- }
- return row[0].(int64), nil
+
+ return rows[0].At(0).(int64), nil
}
func (df *dataFrameImpl) Cov(ctx context.Context, col1, col2 string) (float64,
error) {
@@ -337,12 +334,12 @@ func (df *dataFrameImpl) Cov(ctx context.Context, col1,
col2 string) (float64, e
return 0, err
}
- values, err := types.ReadArrowTable(table)
+ values, err := types.ReadArrowTableToRows(table)
if err != nil {
return 0, err
}
- return values[0][0].(float64), nil
+ return values[0].At(0).(float64), nil
}
func (df *dataFrameImpl) PlanId() int64 {
@@ -454,29 +451,18 @@ func (df *dataFrameImpl) WriteResult(ctx context.Context,
collector ResultCollec
return sparkerrors.WithType(fmt.Errorf("failed to show
dataframe: %w", err), sparkerrors.ExecutionError)
}
- schema, table, err := responseClient.ToTable()
+ _, table, err := responseClient.ToTable()
if err != nil {
return err
}
- rows := make([]Row, table.NumRows())
-
- values, err := types.ReadArrowTable(table)
+ rows, err := types.ReadArrowTableToRows(table)
if err != nil {
return err
}
- for idx, v := range values {
- row := NewRowWithSchema(v, schema)
- rows[idx] = row
- }
-
for _, row := range rows {
- values, err := row.Values()
- if err != nil {
- return sparkerrors.WithType(fmt.Errorf(
- "failed to get values in the row: %w", err),
sparkerrors.ReadError)
- }
+ values := row.Values()
collector.WriteRow(values)
}
return nil
@@ -492,30 +478,17 @@ func (df *dataFrameImpl) Schema(ctx context.Context)
(*types.StructType, error)
return types.ConvertProtoDataTypeToStructType(responseSchema)
}
-func (df *dataFrameImpl) Collect(ctx context.Context) ([]Row, error) {
+func (df *dataFrameImpl) Collect(ctx context.Context) ([]types.Row, error) {
responseClient, err := df.session.client.ExecutePlan(ctx,
df.createPlan())
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf("failed to execute
plan: %w", err), sparkerrors.ExecutionError)
}
- var schema *types.StructType
- schema, table, err := responseClient.ToTable()
- if err != nil {
- return nil, err
- }
-
- rows := make([]Row, table.NumRows())
-
- values, err := types.ReadArrowTable(table)
+ _, table, err := responseClient.ToTable()
if err != nil {
return nil, err
}
-
- for idx, v := range values {
- row := NewRowWithSchema(v, schema)
- rows[idx] = row
- }
- return rows, nil
+ return types.ReadArrowTableToRows(table)
}
func (df *dataFrameImpl) Write() DataFrameWriter {
@@ -852,7 +825,7 @@ func (df *dataFrameImpl) DropDuplicates(ctx
context.Context, columns ...string)
return NewDataFrame(df.session, rel), nil
}
-func (df *dataFrameImpl) Tail(ctx context.Context, limit int32) ([]Row, error)
{
+func (df *dataFrameImpl) Tail(ctx context.Context, limit int32) ([]types.Row,
error) {
rel := &proto.Relation{
Common: &proto.RelationCommon{
PlanId: newPlanId(),
@@ -883,11 +856,11 @@ func (df *dataFrameImpl) Limit(ctx context.Context, limit
int32) DataFrame {
return NewDataFrame(df.session, rel)
}
-func (df *dataFrameImpl) Head(ctx context.Context, limit int32) ([]Row, error)
{
+func (df *dataFrameImpl) Head(ctx context.Context, limit int32) ([]types.Row,
error) {
return df.Limit(ctx, limit).Collect(ctx)
}
-func (df *dataFrameImpl) Take(ctx context.Context, limit int32) ([]Row, error)
{
+func (df *dataFrameImpl) Take(ctx context.Context, limit int32) ([]types.Row,
error) {
return df.Limit(ctx, limit).Collect(ctx)
}
@@ -1260,7 +1233,7 @@ func (df *dataFrameImpl) Distinct(ctx context.Context)
DataFrame {
return NewDataFrame(df.session, rel)
}
-func (df *dataFrameImpl) First(ctx context.Context) (Row, error) {
+func (df *dataFrameImpl) First(ctx context.Context) (types.Row, error) {
rows, err := df.Head(ctx, 1)
if err != nil {
return nil, err
diff --git a/spark/sql/row.go b/spark/sql/row.go
deleted file mode 100644
index 4f821af..0000000
--- a/spark/sql/row.go
+++ /dev/null
@@ -1,48 +0,0 @@
-//
-// Licensed to the Apache Software Foundation (ASF) under one or more
-// contributor license agreements. See the NOTICE file distributed with
-// this work for additional information regarding copyright ownership.
-// The ASF licenses this file to You under the Apache License, Version 2.0
-// (the "License"); you may not use this file except in compliance with
-// the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sql
-
-import "github.com/apache/spark-connect-go/v35/spark/sql/types"
-
-// Row represents a row in a DataFrame.
-type Row interface {
- // Schema returns the schema of the row.
- Schema() (*types.StructType, error)
- // Values returns the values of the row.
- Values() ([]any, error)
-}
-
-// genericRowWithSchema represents a row in a DataFrame with schema.
-type genericRowWithSchema struct {
- values []any
- schema *types.StructType
-}
-
-func NewRowWithSchema(values []any, schema *types.StructType) Row {
- return &genericRowWithSchema{
- values: values,
- schema: schema,
- }
-}
-
-func (r *genericRowWithSchema) Schema() (*types.StructType, error) {
- return r.schema, nil
-}
-
-func (r *genericRowWithSchema) Values() ([]any, error) {
- return r.values, nil
-}
diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go
index a78681e..80f58f3 100644
--- a/spark/sql/sparksession_test.go
+++ b/spark/sql/sparksession_test.go
@@ -187,7 +187,7 @@ func TestWriteResultStreamsArrowResultToCollector(t
*testing.T) {
assert.NoError(t, err)
rows, err := df.Collect(ctx)
assert.NoError(t, err)
- vals, err := rows[1].Values()
+ vals := rows[1].Values()
assert.NoError(t, err)
assert.Equal(t, []any{"str2"}, vals)
}
diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go
index 02bd7c5..0131c82 100644
--- a/spark/sql/types/arrow.go
+++ b/spark/sql/types/arrow.go
@@ -29,159 +29,195 @@ import (
"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
)
-func ReadArrowTable(table arrow.Table) ([][]any, error) {
- numRows := table.NumRows()
- numColumns := int(table.NumCols())
+func ReadArrowTableToRows(table arrow.Table) ([]Row, error) {
+ result := make([]Row, table.NumRows())
- values := make([][]any, numRows)
- for i := range values {
- values[i] = make([]any, numColumns)
- }
-
- for columnIndex := 0; columnIndex < numColumns; columnIndex++ {
- err := ReadArrowRecordColumn(table, columnIndex, values)
+ // For each column in the table, read the data and convert it to an
array of any.
+ cols := make([][]any, table.NumCols())
+ for i := 0; i < int(table.NumCols()); i++ {
+ chunkedColumn := table.Column(i).Data()
+ column, err := readChunkedColumn(chunkedColumn)
if err != nil {
return nil, err
}
+ cols[i] = column
+ }
+
+ // Create a list of field names for the rows.
+ fieldNames := make([]string, table.NumCols())
+ for i, field := range table.Schema().Fields() {
+ fieldNames[i] = field.Name
}
- return values, nil
+
+ // Create the rows:
+ for i := 0; i < int(table.NumRows()); i++ {
+ row := make([]any, table.NumCols())
+ for j := 0; j < int(table.NumCols()); j++ {
+ row[j] = cols[j][i]
+ }
+ r := &rowImpl{
+ values: row,
+ offsets: make(map[string]int),
+ }
+ for j, fieldName := range fieldNames {
+ r.offsets[fieldName] = j
+ }
+ result[i] = r
+ }
+
+ return result, nil
}
-// readArrowRecordColumn reads all values in a column and stores them in values
-func ReadArrowRecordColumn(record arrow.Table, columnIndex int, values
[][]any) error {
- chunkedColumn := record.Column(columnIndex).Data()
- dataTypeId := chunkedColumn.DataType().ID()
- switch dataTypeId {
+func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) {
+ buf := make([]any, 0)
+ // Switch over the type t and append the values to buf.
+ switch t {
case arrow.BOOL:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewBooleanData(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewBooleanData(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.INT8:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewInt8Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewInt8Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.INT16:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewInt16Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewInt16Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.INT32:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewInt32Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewInt32Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.INT64:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewInt64Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewInt64Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.FLOAT16:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewFloat16Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewFloat16Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.FLOAT32:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewFloat32Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewFloat32Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.FLOAT64:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewFloat64Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewFloat64Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.DECIMAL | arrow.DECIMAL128:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewDecimal128Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewDecimal128Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.DECIMAL256:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewDecimal256Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewDecimal256Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.STRING:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewStringData(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewStringData(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.BINARY:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewBinaryData(columnData.Data())
- for i := 0; rowIndex < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewBinaryData(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.TIMESTAMP:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewTimestampData(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
- }
+ data := array.NewTimestampData(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
}
case arrow.DATE64:
- rowIndex := 0
- for _, columnData := range chunkedColumn.Chunks() {
- vector := array.NewDate64Data(columnData.Data())
- for i := 0; i < columnData.Len(); i++ {
- values[rowIndex][columnIndex] = vector.Value(i)
- rowIndex += 1
+ data := array.NewDate64Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
+ }
+ case arrow.DATE32:
+ data := array.NewDate32Data(data)
+ for i := 0; i < data.Len(); i++ {
+ buf = append(buf, data.Value(i))
+ }
+ case arrow.LIST:
+ data := array.NewListData(data)
+ values := data.ListValues()
+
+ res, err := readArrayData(values.DataType().ID(), values.Data())
+ if err != nil {
+ return nil, err
+ }
+
+ for i := 0; i < data.Len(); i++ {
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ continue
}
+ start := data.Offsets()[i]
+ end := data.Offsets()[i+1]
+ // TODO: Unfortunately, this ends up being stored as a
slice of slices of any. But not
+ // the right type.
+ buf = append(buf, res[start:end])
+ }
+ case arrow.MAP:
+ // For maps the data is stored as a list of key value pairs. So
to extract the maps,
+ // we follow the same behavior as for lists but with two sub
lists.
+ data := array.NewMapData(data)
+ keys := data.Keys()
+ values := data.Items()
+
+ keyValues, err := readArrayData(keys.DataType().ID(),
keys.Data())
+ if err != nil {
+ return nil, err
+ }
+ valueValues, err := readArrayData(values.DataType().ID(),
values.Data())
+ if err != nil {
+ return nil, err
+ }
+
+ for i := 0; i < data.Len(); i++ {
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ continue
+ }
+ tmp := make(map[any]any)
+
+ start := data.Offsets()[i]
+ end := data.Offsets()[i+1]
+
+ k := keyValues[start:end]
+ v := valueValues[start:end]
+ for j := 0; j < len(k); j++ {
+ tmp[k[j]] = v[j]
+ }
+ buf = append(buf, tmp)
}
default:
- return fmt.Errorf("unsupported arrow data type %s in column
%d", dataTypeId.String(), columnIndex)
+ return nil, fmt.Errorf("unsupported arrow data type %s",
t.String())
+ }
+ return buf, nil
+}
+
+func readChunkedColumn(chunked *arrow.Chunked) ([]any, error) {
+ buf := make([]any, 0)
+ for _, chunk := range chunked.Chunks() {
+ data := chunk.Data()
+ t := data.DataType().ID()
+ values, err := readArrayData(t, data)
+ if err != nil {
+ return nil, err
+ }
+ buf = append(buf, values...)
}
- return nil
+ return buf, nil
}
func ReadArrowBatchToRecord(data []byte, schema *StructType) (arrow.Record,
error) {
diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go
index 30fc14b..8e99c78 100644
--- a/spark/sql/types/arrow_test.go
+++ b/spark/sql/types/arrow_test.go
@@ -64,11 +64,11 @@ func TestShowArrowBatchData(t *testing.T) {
require.NoError(t, err)
table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record})
- values, err := types.ReadArrowTable(table)
+ values, err := types.ReadArrowTableToRows(table)
require.Nil(t, err)
assert.Equal(t, 2, len(values))
- assert.Equal(t, []any{"str1a\nstr1b"}, values[0])
- assert.Equal(t, []any{"str2"}, values[1])
+ assert.Equal(t, []any{"str1a\nstr1b"}, values[0].Values())
+ assert.Equal(t, []any{"str2"}, values[1].Values())
}
func TestReadArrowRecord(t *testing.T) {
@@ -129,6 +129,14 @@ func TestReadArrowRecord(t *testing.T) {
Name: "date64_column",
Type: &arrow.Date64Type{},
},
+ {
+ Name: "array_int64_column",
+ Type: arrow.ListOf(arrow.PrimitiveTypes.Int64),
+ },
+ {
+ Name: "map_string_int32",
+ Type: arrow.MapOf(arrow.BinaryTypes.String,
arrow.PrimitiveTypes.Int32),
+ },
}
arrowSchema := arrow.NewSchema(arrowFields, nil)
var buf bytes.Buffer
@@ -195,11 +203,32 @@ func TestReadArrowRecord(t *testing.T) {
recordBuilder.Field(i).(*array.Date64Builder).Append(arrow.Date64(1686981953117000))
recordBuilder.Field(i).(*array.Date64Builder).Append(arrow.Date64(1686981953118000))
+ i++
+ lb := recordBuilder.Field(i).(*array.ListBuilder)
+ lb.Append(true)
+ lb.ValueBuilder().(*array.Int64Builder).Append(1)
+ lb.ValueBuilder().(*array.Int64Builder).Append(-999231)
+
+ lb.Append(true)
+ lb.ValueBuilder().(*array.Int64Builder).Append(1)
+ lb.ValueBuilder().(*array.Int64Builder).Append(2)
+ lb.ValueBuilder().(*array.Int64Builder).Append(3)
+
+ i++
+ mb := recordBuilder.Field(i).(*array.MapBuilder)
+ mb.Append(true)
+ mb.KeyBuilder().(*array.StringBuilder).Append("key1")
+ mb.ItemBuilder().(*array.Int32Builder).Append(1)
+
+ mb.Append(true)
+ mb.KeyBuilder().(*array.StringBuilder).Append("key2")
+ mb.ItemBuilder().(*array.Int32Builder).Append(2)
+
record := recordBuilder.NewRecord()
defer record.Release()
table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record})
- values, err := types.ReadArrowTable(table)
+ values, err := types.ReadArrowTableToRows(table)
require.Nil(t, err)
assert.Equal(t, 2, len(values))
assert.Equal(t, []any{
@@ -208,16 +237,20 @@ func TestReadArrowRecord(t *testing.T) {
decimal128.FromI64(10000000), decimal256.FromI64(100000000),
"str1", []byte("bytes1"),
arrow.Timestamp(1686981953115000),
arrow.Date64(1686981953117000),
+ []any{int64(1), int64(-999231)},
+ map[any]any{"key1": int32(1)},
},
- values[0])
+ values[0].Values())
assert.Equal(t, []any{
true, int8(2), int16(20), int32(200), int64(2000),
float16.New(20000.1), float32(200000.1), 2000000.1,
decimal128.FromI64(20000000), decimal256.FromI64(200000000),
"str2", []byte("bytes2"),
arrow.Timestamp(1686981953116000),
arrow.Date64(1686981953118000),
+ []any{int64(1), int64(2), int64(3)},
+ map[any]any{"key2": int32(2)},
},
- values[1])
+ values[1].Values())
}
func TestReadArrowRecord_UnsupportedType(t *testing.T) {
@@ -242,7 +275,7 @@ func TestReadArrowRecord_UnsupportedType(t *testing.T) {
defer record.Release()
table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record})
- _, err := types.ReadArrowTable(table)
+ _, err := types.ReadArrowTableToRows(table)
require.NotNil(t, err)
}
diff --git a/spark/sql/row_test.go b/spark/sql/types/row.go
similarity index 57%
rename from spark/sql/row_test.go
rename to spark/sql/types/row.go
index c4a384b..b73ef57 100644
--- a/spark/sql/row_test.go
+++ b/spark/sql/types/row.go
@@ -14,30 +14,41 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package sql
+package types
-import (
- "testing"
+type Row interface {
+ At(index int) any
+ Value(name string) any
+ Values() []any
+ Len() int
+ FieldNames() []string
+}
+
+type rowImpl struct {
+ values []any
+ offsets map[string]int
+}
- "github.com/apache/spark-connect-go/v35/spark/sql/types"
+func (r *rowImpl) At(index int) any {
+ return r.values[index]
+}
- "github.com/stretchr/testify/assert"
-)
+func (r *rowImpl) Value(name string) any {
+ return r.values[r.offsets[name]]
+}
+
+func (r *rowImpl) Values() []any {
+ return r.values
+}
-func TestSchema(t *testing.T) {
- values := []any{1}
- schema := &types.StructType{}
- row := NewRowWithSchema(values, schema)
- schema2, err := row.Schema()
- assert.NoError(t, err)
- assert.Equal(t, schema, schema2)
+func (r *rowImpl) Len() int {
+ return len(r.values)
}
-func TestValues(t *testing.T) {
- values := []any{1}
- schema := &types.StructType{}
- row := NewRowWithSchema(values, schema)
- values2, err := row.Values()
- assert.NoError(t, err)
- assert.Equal(t, values, values2)
+func (r *rowImpl) FieldNames() []string {
+ names := make([]string, len(r.offsets))
+ for name := range r.offsets {
+ names = append(names, name)
+ }
+ return names
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]