This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new d05e1ba14 fix(go/adbc/driver/snowflake): split files properly after reaching targetSize on ingestion (#2026)
d05e1ba14 is described below

commit d05e1ba142fb38d47749cb7904a7db4182093e64
Author: Joel Lubinitsky <[email protected]>
AuthorDate: Mon Jul 29 17:54:23 2024 -0400

    fix(go/adbc/driver/snowflake): split files properly after reaching targetSize on ingestion (#2026)
    
    Fixes: #1997
    
    **Core Changes**
    - Change ingestion `writeParquet` function to use unbuffered writer,
    skipping 0-row records to avoid recurrence of #1847
    - Use parquet writer's internal `RowGroupTotalBytesWritten()` method to
    track output file size in favor of `limitWriter`
    - Unit test to validate that file cutoff occurs precisely when expected
    
    **Secondary Changes**
    - Bump arrow dependency to `v18` to pull in the changes from
    [apache/arrow#43326](https://github.com/apache/arrow/pull/43326)
    - Fix flightsql test that depends on hardcoded arrow version
---
 go/adbc/driver/snowflake/bulk_ingestion.go      |  44 ++++------
 go/adbc/driver/snowflake/bulk_ingestion_test.go | 108 ++++++++++++++++++++++++
 go/adbc/pkg/_tmpl/driver.go.tmpl                |   8 +-
 3 files changed, 127 insertions(+), 33 deletions(-)
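
For readers skimming the patch, here is a minimal, self-contained sketch of the reworked write loop described above: an unbuffered pqarrow writer, zero-row records skipped, and file cutoff driven by `RowGroupTotalBytesWritten()` instead of the removed `limitWriter`. The package name, function name (`writeChunk`), and default writer properties are illustrative assumptions, not the driver's actual symbols; the real change is in `writeParquet` in the diff below.

```go
// Hypothetical package name for this sketch; the real code lives in
// go/adbc/driver/snowflake/bulk_ingestion.go.
package ingest

import (
	"io"

	"github.com/apache/arrow/go/v18/arrow"
	"github.com/apache/arrow/go/v18/parquet"
	"github.com/apache/arrow/go/v18/parquet/pqarrow"
)

// writeChunk mirrors the patched writeParquet: each record is written
// unbuffered as its own row group, and the function returns nil as soon
// as the accumulated row-group bytes reach targetSize. A negative
// targetSize disables the limit; io.EOF means the channel was drained.
func writeChunk(schema *arrow.Schema, w io.Writer, in <-chan arrow.Record, targetSize int) error {
	// Default properties keep the sketch compilable; the driver threads
	// the caller's writer properties through instead.
	pqWriter, err := pqarrow.NewFileWriter(schema, w, parquet.NewWriterProperties(), pqarrow.DefaultWriterProps())
	if err != nil {
		return err
	}
	defer pqWriter.Close()

	var bytesWritten int64
	for rec := range in {
		if rec.NumRows() == 0 {
			// Zero-row records would emit empty row groups (#1847).
			rec.Release()
			continue
		}

		err = pqWriter.Write(rec)
		rec.Release()
		if err != nil {
			return err
		}

		if targetSize < 0 {
			continue // file-size limit disabled
		}

		// Unbuffered writes close one row group per record, so this
		// running total tracks the bytes actually flushed to w.
		bytesWritten += pqWriter.RowGroupTotalBytesWritten()
		if bytesWritten >= int64(targetSize) {
			return nil // target reached; the caller rotates to a new file
		}
	}
	return io.EOF
}
```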

diff --git a/go/adbc/driver/snowflake/bulk_ingestion.go b/go/adbc/driver/snowflake/bulk_ingestion.go
index 8b80ee49d..7f3d6bbd2 100644
--- a/go/adbc/driver/snowflake/bulk_ingestion.go
+++ b/go/adbc/driver/snowflake/bulk_ingestion.go
@@ -334,20 +334,31 @@ func writeParquet(
        parquetProps *parquet.WriterProperties,
        arrowProps pqarrow.ArrowWriterProperties,
 ) error {
-       limitWr := &limitWriter{w: w, limit: targetSize}
-       pqWriter, err := pqarrow.NewFileWriter(schema, limitWr, parquetProps, arrowProps)
+       pqWriter, err := pqarrow.NewFileWriter(schema, w, parquetProps, arrowProps)
        if err != nil {
                return err
        }
        defer pqWriter.Close()
 
+       var bytesWritten int64
        for rec := range in {
-               err = pqWriter.WriteBuffered(rec)
+               if rec.NumRows() == 0 {
+                       rec.Release()
+                       continue
+               }
+
+               err = pqWriter.Write(rec)
                rec.Release()
                if err != nil {
                        return err
                }
-               if limitWr.LimitExceeded() {
+
+               if targetSize < 0 {
+                       continue
+               }
+
+               bytesWritten += pqWriter.RowGroupTotalBytesWritten()
+               if bytesWritten >= int64(targetSize) {
                        return nil
                }
        }
@@ -584,28 +595,3 @@ func (bp *bufferPool) PutBuffer(buf *bytes.Buffer) {
        buf.Reset()
        bp.Pool.Put(buf)
 }
-
-// Wraps an io.Writer and specifies a limit.
-// Keeps track of how many bytes have been written and can report whether the limit has been exceeded.
-// TODO(ARROW-39789): We prefer to use RowGroupTotalBytesWritten on the ParquetWriter, but there seems to be a discrepancy with the count.
-type limitWriter struct {
-       w     io.Writer
-       limit int
-
-       bytesWritten int
-}
-
-func (lw *limitWriter) Write(p []byte) (int, error) {
-       n, err := lw.w.Write(p)
-       lw.bytesWritten += n
-
-       return n, err
-}
-
-func (lw *limitWriter) LimitExceeded() bool {
-       if lw.limit > 0 {
-               return lw.bytesWritten > lw.limit
-       }
-       // Limit disabled
-       return false
-}
diff --git a/go/adbc/driver/snowflake/bulk_ingestion_test.go b/go/adbc/driver/snowflake/bulk_ingestion_test.go
new file mode 100644
index 000000000..6ae0e3a4d
--- /dev/null
+++ b/go/adbc/driver/snowflake/bulk_ingestion_test.go
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+       "bytes"
+       "context"
+       "fmt"
+       "io"
+       "testing"
+
+       "github.com/apache/arrow/go/v18/arrow"
+       "github.com/apache/arrow/go/v18/arrow/array"
+       "github.com/apache/arrow/go/v18/arrow/memory"
+       "github.com/apache/arrow/go/v18/parquet/pqarrow"
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+)
+
+func TestIngestBatchedParquetWithFileLimit(t *testing.T) {
+       var buf bytes.Buffer
+       ctx := context.Background()
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       ingestOpts := DefaultIngestOptions()
+       parquetProps, arrowProps := newWriterProps(mem, ingestOpts)
+
+       nCols := 3
+       nRecs := 10
+       nRows := 1000
+       targetFileSize := 10000
+
+       rec := makeRec(mem, nCols, nRows)
+       defer rec.Release()
+
+       // Create a temporary parquet writer and write a single row group so we know
+       // approximately how many bytes it should take
+       tempWriter, err := pqarrow.NewFileWriter(rec.Schema(), &buf, parquetProps, arrowProps)
+       require.NoError(t, err)
+
+       // Write 1 record and check the size before closing so footer bytes are not included
+       require.NoError(t, tempWriter.Write(rec))
+       expectedRowGroupSize := buf.Len()
+       require.NoError(t, tempWriter.Close())
+
+       recs := make([]arrow.Record, nRecs)
+       for i := 0; i < nRecs; i++ {
+               recs[i] = rec
+       }
+
+       rdr, err := array.NewRecordReader(rec.Schema(), recs)
+       require.NoError(t, err)
+       defer rdr.Release()
+
+       records := make(chan arrow.Record)
+       go func() { assert.NoError(t, readRecords(ctx, rdr, records)) }()
+
+       buf.Reset()
+       // Expected to read multiple records but then stop after targetFileSize, indicated by nil error
+       require.NoError(t, writeParquet(rdr.Schema(), &buf, records, targetFileSize, parquetProps, arrowProps))
+
+       // Expect to exceed the targetFileSize but by no more than the size of 1 row group
+       assert.Greater(t, buf.Len(), targetFileSize)
+       assert.Less(t, buf.Len(), targetFileSize+expectedRowGroupSize)
+
+       // Drain the remaining records with no limit on file size, expect EOF
+       require.ErrorIs(t, writeParquet(rdr.Schema(), &buf, records, -1, parquetProps, arrowProps), io.EOF)
+}
+
+func makeRec(mem memory.Allocator, nCols, nRows int) arrow.Record {
+       vals := make([]int8, nRows)
+       for val := 0; val < nRows; val++ {
+               vals[val] = int8(val)
+       }
+
+       bldr := array.NewInt8Builder(mem)
+       defer bldr.Release()
+
+       bldr.AppendValues(vals, nil)
+       arr := bldr.NewArray()
+       defer arr.Release()
+
+       fields := make([]arrow.Field, nCols)
+       cols := make([]arrow.Array, nCols)
+       for i := 0; i < nCols; i++ {
+               fields[i] = arrow.Field{Name: fmt.Sprintf("field_%d", i), Type: arrow.PrimitiveTypes.Int8}
+               cols[i] = arr // array.NewRecord will retain these
+       }
+
+       schema := arrow.NewSchema(fields, nil)
+       return array.NewRecord(schema, cols, int64(nRows))
+}
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 8dc352c0e..77513949f 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -59,10 +59,10 @@ import (
        "unsafe"
 
        "github.com/apache/arrow-adbc/go/adbc"
-       "github.com/apache/arrow/go/v17/arrow/array"
-       "github.com/apache/arrow/go/v17/arrow/cdata"
-       "github.com/apache/arrow/go/v17/arrow/memory"
-       "github.com/apache/arrow/go/v17/arrow/memory/mallocator"
+       "github.com/apache/arrow/go/v18/arrow/array"
+       "github.com/apache/arrow/go/v18/arrow/cdata"
+       "github.com/apache/arrow/go/v18/arrow/memory"
+       "github.com/apache/arrow/go/v18/arrow/memory/mallocator"
 )
 
 // Must use malloc() to respect CGO rules
