This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new d05e1ba14 fix(go/adbc/driver/snowflake): split files properly after reaching targetSize on ingestion (#2026)
d05e1ba14 is described below
commit d05e1ba142fb38d47749cb7904a7db4182093e64
Author: Joel Lubinitsky <[email protected]>
AuthorDate: Mon Jul 29 17:54:23 2024 -0400
fix(go/adbc/driver/snowflake): split files properly after reaching targetSize on ingestion (#2026)
Fixes: #1997
**Core Changes**
- Change the ingestion `writeParquet` function to use an unbuffered writer, skipping 0-row records to avoid a recurrence of #1847
- Use the parquet writer's internal `RowGroupTotalBytesWritten()` method to track output file size instead of the removed `limitWriter` helper (see the sketch after this list)
- Unit test to validate that file cutoff occurs precisely when expected
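For readers skimming the diff, here is a minimal, self-contained sketch of the size-tracking pattern this change introduces. The function name `writeUntilTarget`, the package name, and the default writer properties are illustrative assumptions only; the driver's actual implementation is the unexported `writeParquet` changed in `bulk_ingestion.go` below.

```go
// Package ingestsketch is a hypothetical package illustrating the pattern;
// it is not part of the ADBC driver.
package ingestsketch

import (
	"io"

	"github.com/apache/arrow/go/v18/arrow"
	"github.com/apache/arrow/go/v18/parquet"
	"github.com/apache/arrow/go/v18/parquet/pqarrow"
)

// writeUntilTarget writes records from in to w as one Parquet file.
// It returns nil once the accumulated row-group bytes reach targetSize
// (the caller can then start a new file) and io.EOF when in is drained.
func writeUntilTarget(schema *arrow.Schema, w io.Writer, in <-chan arrow.Record, targetSize int64) error {
	// Default properties are used here for brevity; the driver derives its
	// own from the ingest options.
	pqWriter, err := pqarrow.NewFileWriter(schema, w, parquet.NewWriterProperties(), pqarrow.DefaultWriterProps())
	if err != nil {
		return err
	}
	defer pqWriter.Close()

	var bytesWritten int64
	for rec := range in {
		if rec.NumRows() == 0 {
			// Zero-row records would produce empty row groups (#1847), so skip them.
			rec.Release()
			continue
		}

		// Unbuffered Write: each record is flushed as its own row group.
		err = pqWriter.Write(rec)
		rec.Release()
		if err != nil {
			return err
		}

		if targetSize < 0 {
			continue // size limit disabled
		}

		// Accumulate the per-row-group byte count to approximate the file size so far.
		bytesWritten += pqWriter.RowGroupTotalBytesWritten()
		if bytesWritten >= targetSize {
			return nil // target reached; remaining records go to the next file
		}
	}
	return io.EOF
}
```

A caller that wants multiple output files would invoke this in a loop with a fresh writer each time, treating a nil return as "cut a new file" and `io.EOF` as "stream drained"; that contract is what the new unit test exercises.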
**Secondary Changes**
- Bump the arrow dependency to `v18` to pull in the changes from [apache/arrow#43326](https://github.com/apache/arrow/pull/43326)
- Fix a flightsql test that depends on a hardcoded arrow version
---
go/adbc/driver/snowflake/bulk_ingestion.go | 44 ++++------
go/adbc/driver/snowflake/bulk_ingestion_test.go | 108 ++++++++++++++++++++++++
go/adbc/pkg/_tmpl/driver.go.tmpl | 8 +-
3 files changed, 127 insertions(+), 33 deletions(-)
diff --git a/go/adbc/driver/snowflake/bulk_ingestion.go b/go/adbc/driver/snowflake/bulk_ingestion.go
index 8b80ee49d..7f3d6bbd2 100644
--- a/go/adbc/driver/snowflake/bulk_ingestion.go
+++ b/go/adbc/driver/snowflake/bulk_ingestion.go
@@ -334,20 +334,31 @@ func writeParquet(
parquetProps *parquet.WriterProperties,
arrowProps pqarrow.ArrowWriterProperties,
) error {
- limitWr := &limitWriter{w: w, limit: targetSize}
- pqWriter, err := pqarrow.NewFileWriter(schema, limitWr, parquetProps, arrowProps)
+ pqWriter, err := pqarrow.NewFileWriter(schema, w, parquetProps, arrowProps)
if err != nil {
return err
}
defer pqWriter.Close()
+ var bytesWritten int64
for rec := range in {
- err = pqWriter.WriteBuffered(rec)
+ if rec.NumRows() == 0 {
+ rec.Release()
+ continue
+ }
+
+ err = pqWriter.Write(rec)
rec.Release()
if err != nil {
return err
}
- if limitWr.LimitExceeded() {
+
+ if targetSize < 0 {
+ continue
+ }
+
+ bytesWritten += pqWriter.RowGroupTotalBytesWritten()
+ if bytesWritten >= int64(targetSize) {
return nil
}
}
@@ -584,28 +595,3 @@ func (bp *bufferPool) PutBuffer(buf *bytes.Buffer) {
buf.Reset()
bp.Pool.Put(buf)
}
-
-// Wraps an io.Writer and specifies a limit.
-// Keeps track of how many bytes have been written and can report whether the limit has been exceeded.
-// TODO(ARROW-39789): We prefer to use RowGroupTotalBytesWritten on the ParquetWriter, but there seems to be a discrepency with the count.
-type limitWriter struct {
- w io.Writer
- limit int
-
- bytesWritten int
-}
-
-func (lw *limitWriter) Write(p []byte) (int, error) {
- n, err := lw.w.Write(p)
- lw.bytesWritten += n
-
- return n, err
-}
-
-func (lw *limitWriter) LimitExceeded() bool {
- if lw.limit > 0 {
- return lw.bytesWritten > lw.limit
- }
- // Limit disabled
- return false
-}
diff --git a/go/adbc/driver/snowflake/bulk_ingestion_test.go b/go/adbc/driver/snowflake/bulk_ingestion_test.go
new file mode 100644
index 000000000..6ae0e3a4d
--- /dev/null
+++ b/go/adbc/driver/snowflake/bulk_ingestion_test.go
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "testing"
+
+ "github.com/apache/arrow/go/v18/arrow"
+ "github.com/apache/arrow/go/v18/arrow/array"
+ "github.com/apache/arrow/go/v18/arrow/memory"
+ "github.com/apache/arrow/go/v18/parquet/pqarrow"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestIngestBatchedParquetWithFileLimit(t *testing.T) {
+ var buf bytes.Buffer
+ ctx := context.Background()
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0)
+
+ ingestOpts := DefaultIngestOptions()
+ parquetProps, arrowProps := newWriterProps(mem, ingestOpts)
+
+ nCols := 3
+ nRecs := 10
+ nRows := 1000
+ targetFileSize := 10000
+
+ rec := makeRec(mem, nCols, nRows)
+ defer rec.Release()
+
+ // Create a temporary parquet writer and write a single row group so we know
+ // approximately how many bytes it should take
+ tempWriter, err := pqarrow.NewFileWriter(rec.Schema(), &buf, parquetProps, arrowProps)
+ require.NoError(t, err)
+
+ // Write 1 record and check the size before closing so footer bytes are not included
+ require.NoError(t, tempWriter.Write(rec))
+ expectedRowGroupSize := buf.Len()
+ require.NoError(t, tempWriter.Close())
+
+ recs := make([]arrow.Record, nRecs)
+ for i := 0; i < nRecs; i++ {
+ recs[i] = rec
+ }
+
+ rdr, err := array.NewRecordReader(rec.Schema(), recs)
+ require.NoError(t, err)
+ defer rdr.Release()
+
+ records := make(chan arrow.Record)
+ go func() { assert.NoError(t, readRecords(ctx, rdr, records)) }()
+
+ buf.Reset()
+ // Expected to read multiple records but then stop after targetFileSize, indicated by nil error
+ require.NoError(t, writeParquet(rdr.Schema(), &buf, records, targetFileSize, parquetProps, arrowProps))
+
+ // Expect to exceed the targetFileSize but by no more than the size of 1 row group
+ assert.Greater(t, buf.Len(), targetFileSize)
+ assert.Less(t, buf.Len(), targetFileSize+expectedRowGroupSize)
+
+ // Drain the remaining records with no limit on file size, expect EOF
+ require.ErrorIs(t, writeParquet(rdr.Schema(), &buf, records, -1, parquetProps, arrowProps), io.EOF)
+}
+
+func makeRec(mem memory.Allocator, nCols, nRows int) arrow.Record {
+ vals := make([]int8, nRows)
+ for val := 0; val < nRows; val++ {
+ vals[val] = int8(val)
+ }
+
+ bldr := array.NewInt8Builder(mem)
+ defer bldr.Release()
+
+ bldr.AppendValues(vals, nil)
+ arr := bldr.NewArray()
+ defer arr.Release()
+
+ fields := make([]arrow.Field, nCols)
+ cols := make([]arrow.Array, nCols)
+ for i := 0; i < nCols; i++ {
+ fields[i] = arrow.Field{Name: fmt.Sprintf("field_%d", i), Type: arrow.PrimitiveTypes.Int8}
+ cols[i] = arr // array.NewRecord will retain these
+ }
+
+ schema := arrow.NewSchema(fields, nil)
+ return array.NewRecord(schema, cols, int64(nRows))
+}
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 8dc352c0e..77513949f 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -59,10 +59,10 @@ import (
"unsafe"
"github.com/apache/arrow-adbc/go/adbc"
- "github.com/apache/arrow/go/v17/arrow/array"
- "github.com/apache/arrow/go/v17/arrow/cdata"
- "github.com/apache/arrow/go/v17/arrow/memory"
- "github.com/apache/arrow/go/v17/arrow/memory/mallocator"
+ "github.com/apache/arrow/go/v18/arrow/array"
+ "github.com/apache/arrow/go/v18/arrow/cdata"
+ "github.com/apache/arrow/go/v18/arrow/memory"
+ "github.com/apache/arrow/go/v18/arrow/memory/mallocator"
)
// Must use malloc() to respect CGO rules