zeroshade commented on code in PR #1456:
URL: https://github.com/apache/arrow-adbc/pull/1456#discussion_r1465270544


##########
go/adbc/driver/snowflake/bulk_ingestion.go:
##########
@@ -0,0 +1,549 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+       "bufio"
+       "bytes"
+       "compress/flate"
+       "context"
+       "database/sql"
+       "database/sql/driver"
+       "errors"
+       "fmt"
+       "io"
+       "runtime"
+       "strings"
+       "sync"
+
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow/go/v15/arrow"
+       "github.com/apache/arrow/go/v15/arrow/array"
+       "github.com/apache/arrow/go/v15/arrow/memory"
+       "github.com/apache/arrow/go/v15/parquet"
+       "github.com/apache/arrow/go/v15/parquet/compress"
+       "github.com/apache/arrow/go/v15/parquet/pqarrow"
+       "github.com/snowflakedb/gosnowflake"
+       "golang.org/x/sync/errgroup"
+)
+
+const (
+       bindStageName            = "ADBC$BIND"
+       createTemporaryStageStmt = "CREATE OR REPLACE TEMPORARY STAGE " + 
bindStageName + " FILE_FORMAT = (TYPE = PARQUET USE_LOGICAL_TYPE = TRUE 
BINARY_AS_TEXT = FALSE)"
+       putQueryTmpl             = "PUT 'file:///tmp/placeholder/%s' @" + 
bindStageName + " OVERWRITE = TRUE"
+       copyQuery                = "COPY INTO IDENTIFIER(?) FROM @" + 
bindStageName + " MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE"
+       countQuery               = "SELECT COUNT(*) FROM IDENTIFIER(?)"
+       megabyte                 = 1024 * 1024
+)
+
+var (
+       defaultTargetFileSize    uint = 10 * megabyte
+       defaultWriterConcurrency uint = uint(runtime.NumCPU())
+       defaultUploadConcurrency uint = 8
+       defaultCopyConcurrency   uint = 4
+
+       defaultCompressionCodec compress.Compression = compress.Codecs.Snappy
+       defaultCompressionLevel int                  = flate.DefaultCompression
+)
+
+// Options for configuring bulk ingestion.
+//
+// Values should be updated with appropriate calls to stmt.SetOption().
+type ingestOptions struct {
+       // Approximate size of Parquet files written during ingestion.
+       //
+       // Actual size will be slightly larger, depending on size of 
footer/metadata.
+       // Default is 10 MB. If set to 0, file size has no limit. Cannot be 
negative.
+       targetFileSize uint
+       // Number of Parquet files to write in parallel.
+       //
+       // Default attempts to maximize workers based on logical cores 
detected, but
+       // may need to be adjusted if running in a constrained environment.
+       // If set to 0, default value is used. Cannot be negative.
+       writerConcurrency uint
+       // Number of Parquet files to upload in parallel.
+       //
+       // Greater concurrency can smooth out TCP congestion and help make use 
of
+       // available network bandwidth, but will increase memory utilization.
+       // Default is 8. If set to 0, default value is used. Cannot be negative.
+       uploadConcurrency uint
+       // Maximum number of COPY operations to run concurrently.
+       //
+       // Bulk ingestion performance is optimized by executing COPY queries as 
files are
+       // still being uploaded. Snowflake COPY speed scales with warehouse 
size, so smaller
+       // warehouses may benefit from setting this value higher to ensure 
long-running
+       // COPY queries do not block newly uploaded files from being loaded.
+       // Default is 4. If set to 0, only a single COPY query will be executed 
as part of ingestion,
+       // once all files have finished uploading. Cannot be negative.
+       copyConcurrency uint
+       // Compression codec to use for Parquet files.
+       //
+       // When network speeds are high, it is generally faster to use a faster 
codec with
+       // a lower compression ratio. The opposite is true if the network is 
slow but CPU is
+       // available.
+       // Default is Snappy.
+       compressionCodec compress.Compression
+       // Compression level for Parquet files.
+       //
+       // The compression level is codec-specific. Some codecs do not support 
setting it,
+       // notably Snappy.
+       // Default is the default level for the specified compressionCodec.
+       compressionLevel int
+}
+
+func DefaultIngestOptions() *ingestOptions {
+       return &ingestOptions{
+               targetFileSize:    defaultTargetFileSize,
+               writerConcurrency: defaultWriterConcurrency,
+               uploadConcurrency: defaultUploadConcurrency,
+               copyConcurrency:   defaultCopyConcurrency,
+               compressionCodec:  defaultCompressionCodec,
+               compressionLevel:  defaultCompressionLevel,
+       }
+}
+
+// ingestRecord performs bulk ingestion of a single Record and returns the
+// number of rows affected.
+//
+// The Record must already be bound by calling stmt.Bind(), and will be 
released
+// and reset upon completion.
+func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err 
error) {
+       defer func() {
+               // Record already released by writeParquet()
+               st.bound = nil
+       }()
+
+       parquetProps, arrowProps := newWriterProps(st.alloc, st.ingestOptions)
+       g := errgroup.Group{}
+
+       // writeParquet takes a channel of Records, but we only have one Record 
to write
+       recordCh := make(chan arrow.Record, 1)
+       recordCh <- st.bound
+       close(recordCh)
+
+       // Read the Record from the channel and write it into the provided 
writer
+       schema := st.bound.Schema()
+       r, w := io.Pipe()
+       bw := bufio.NewWriter(w)
+       g.Go(func() error {
+               defer r.Close()
+               defer bw.Flush()
+
+               err = writeParquet(schema, bw, recordCh, 0, parquetProps, 
arrowProps)
+               if err != io.EOF {
+                       return err
+               }
+               return nil
+       })
+
+       // Create a temporary stage, we can't start uploading until it has been 
created
+       _, err = st.cnxn.cn.ExecContext(ctx, createTemporaryStageStmt, nil)
+       if err != nil {
+               return
+       }
+
+       // Start uploading the file to Snowflake
+       fileName := "0.parquet" // Only writing 1 file, so use same name as 
first file written by ingestStream() for consistency
+       err = uploadStream(ctx, st.cnxn.cn, r, fileName)
+       if err != nil {
+               return
+       }
+
+       // Parquet writing is already done if the upload finished, so we're 
just checking for any errors
+       err = g.Wait()
+       if err != nil {
+               return
+       }
+
+       // Load the uploaded file into the target table
+       _, err = st.cnxn.cn.ExecContext(ctx, copyQuery, 
[]driver.NamedValue{{Value: st.targetTable}})
+       if err != nil {
+               return
+       }
+
+       // Check final row count of target table to get definitive rows affected
+       nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, st.targetTable)
+       return
+}
+
+// ingestStream performs bulk ingestion of a RecordReader and returns the
+// number of rows affected.
+//
+// The RecordReader must already be bound by calling stmt.BindStream(), and 
will
+// be released and reset upon completion.
+func (st *statement) ingestStream(ctx context.Context) (nrows int64, err 
error) {
+       defer func() {
+               st.streamBind.Release()
+               st.streamBind = nil
+       }()
+       defer func() {
+               // Always check the resulting row count, even in the case of an 
error. We may have ingested part of the data.
+               ctx := context.Background() // TODO(joellubi): switch to 
context.WithoutCancel(ctx) once we're on Go 1.21
+               n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, 
st.targetTable)
+               nrows = n
+
+               // Ingestion, row-count check, or both could have failed
+               // Wrap any failures as ADBC errors
+
+               // TODO(joellubi): simplify / improve with errors.Join(err, 
countErr) once we're on Go 1.20
+               if err == nil {
+                       err = errToAdbcErr(adbc.StatusInternal, countErr)
+                       return
+               }
+
+               // Failure in the pipeline itself
+               if errors.Is(err, context.Canceled) {
+                       err = errToAdbcErr(adbc.StatusCancelled, err)
+               } else {
+                       err = errToAdbcErr(adbc.StatusInternal, err)
+               }
+       }()
+
+       parquetProps, arrowProps := newWriterProps(st.alloc, st.ingestOptions)
+       g, gCtx := errgroup.WithContext(ctx)
+
+       // Read records into channel
+       records := make(chan arrow.Record, st.ingestOptions.writerConcurrency)
+       g.Go(func() error {
+               return readRecords(gCtx, st.streamBind, records)
+       })
+
+       // Read records from channel and write Parquet files in parallel to 
buffer pool
+       schema := st.streamBind.Schema()
+       pool := newBufferPool(int(st.ingestOptions.targetFileSize))
+       buffers := make(chan *bytes.Buffer, st.ingestOptions.writerConcurrency)
+       g.Go(func() error {
+               return runParallelParquetWriters(gCtx, schema, 
int(st.ingestOptions.targetFileSize), int(st.ingestOptions.writerConcurrency), 
parquetProps, arrowProps, pool.GetBuffer, records, buffers)

Review Comment:
   split this line for readability?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to