This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f0bdb367 feat(go/adbc/driver/snowflake): support parameter binding (#1808)
8f0bdb367 is described below

commit 8f0bdb367708f36183e7aa3691ff14c17c28d4e2
Author: David Li <[email protected]>
AuthorDate: Tue May 7 23:11:25 2024 +0900

    feat(go/adbc/driver/snowflake): support parameter binding (#1808)
    
    Fixes #1144.
---
 c/driver/snowflake/snowflake_test.cc      |   2 +-
 go/adbc/driver/snowflake/binding.go       | 153 ++++++++++++++++++++++++++++++
 go/adbc/driver/snowflake/concat_reader.go | 107 +++++++++++++++++++++
 go/adbc/driver/snowflake/driver.go        |   2 +-
 go/adbc/driver/snowflake/statement.go     |  56 ++++++++++-
 go/adbc/go.mod                            |   4 +-
 go/adbc/go.sum                            |   8 +-
 7 files changed, 321 insertions(+), 11 deletions(-)

diff --git a/c/driver/snowflake/snowflake_test.cc 
b/c/driver/snowflake/snowflake_test.cc
index 0fe07ecbd..a4d742491 100644
--- a/c/driver/snowflake/snowflake_test.cc
+++ b/c/driver/snowflake/snowflake_test.cc
@@ -146,7 +146,7 @@ class SnowflakeQuirks : public 
adbc_validation::DriverQuirks {
   bool supports_metadata_current_catalog() const override { return false; }
   bool supports_metadata_current_db_schema() const override { return false; }
   bool supports_partitioned_data() const override { return false; }
-  bool supports_dynamic_parameter_binding() const override { return false; }
+  bool supports_dynamic_parameter_binding() const override { return true; }
   bool supports_error_on_incompatible_schema() const override { return false; }
   bool ddl_implicit_commit_txn() const override { return true; }
   std::string db_schema() const override { return schema_; }
diff --git a/go/adbc/driver/snowflake/binding.go 
b/go/adbc/driver/snowflake/binding.go
new file mode 100644
index 000000000..7f6878945
--- /dev/null
+++ b/go/adbc/driver/snowflake/binding.go
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+       "database/sql"
+       "database/sql/driver"
+       "fmt"
+       "io"
+
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow/go/v17/arrow"
+       "github.com/apache/arrow/go/v17/arrow/array"
+)
+
+// convertArrowToNamedValue converts row `index` of `batch` into positional
+// bind parameters, one per column, with 1-based Ordinals in column order.
+//
+// Null slots are expressed through the database/sql Null* wrappers. Narrow
+// integer and float types are widened because Snowflake only recognizes
+// int64 and float64 (see goTypeToSnowflake in gosnowflake). Unsigned types
+// up to 32 bits always fit in int64 and are widened the same way; uint64 is
+// deliberately not supported since it may overflow int64. Any other column
+// type yields an adbc.Error with StatusNotImplemented.
+func convertArrowToNamedValue(batch arrow.Record, index int) ([]driver.NamedValue, error) {
+       // technically, snowflake can bind an array of values at once, but
+       // only for INSERT, so we can't take advantage of that without
+       // analyzing the query ourselves
+       params := make([]driver.NamedValue, batch.NumCols())
+       for i, field := range batch.Schema().Fields() {
+               rawColumn := batch.Column(i)
+               params[i].Ordinal = i + 1
+               switch column := rawColumn.(type) {
+               case *array.Boolean:
+                       params[i].Value = sql.NullBool{
+                               Bool:  column.Value(index),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Float32:
+                       // Snowflake only recognizes float64
+                       params[i].Value = sql.NullFloat64{
+                               Float64: float64(column.Value(index)),
+                               Valid:   column.IsValid(index),
+                       }
+               case *array.Float64:
+                       params[i].Value = sql.NullFloat64{
+                               Float64: column.Value(index),
+                               Valid:   column.IsValid(index),
+                       }
+               case *array.Int8:
+                       // Snowflake only recognizes int64
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Int16:
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Int32:
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Int64:
+                       params[i].Value = sql.NullInt64{
+                               Int64: column.Value(index),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Uint8:
+                       // lossless: every uint8/16/32 value fits in int64
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Uint16:
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Uint32:
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.String:
+                       params[i].Value = sql.NullString{
+                               String: column.Value(index),
+                               Valid:  column.IsValid(index),
+                       }
+               case *array.LargeString:
+                       params[i].Value = sql.NullString{
+                               String: column.Value(index),
+                               Valid:  column.IsValid(index),
+                       }
+               default:
+                       return nil, adbc.Error{
+                               Code: adbc.StatusNotImplemented,
+                               Msg:  fmt.Sprintf("[Snowflake] Unsupported bind param '%s' type %s", field.Name, field.Type.String()),
+                       }
+               }
+       }
+       return params, nil
+}
+
+// snowflakeBindReader walks the rows of the bound parameter data, turning
+// each row into one set of positional bind parameters (NextParams) and,
+// optionally, into one query execution (Next).
+type snowflakeBindReader struct {
+       // executes the statement with one row of parameters; only Next uses it
+       doQuery func([]driver.NamedValue) (array.RecordReader, error)
+       // bound stream of parameter batches; nil when a single batch was bound
+       stream array.RecordReader
+       // batch whose rows are currently being converted; released by this
+       // reader once all of its rows have been consumed
+       currentBatch arrow.Record
+       // row of currentBatch that the next call will convert
+       nextIndex int64
+}
+
+// Release drops the reader's reference to the batch it is reading and to
+// the bound stream, if any. Fields are cleared, so calling it again is a
+// no-op.
+func (r *snowflakeBindReader) Release() {
+       if batch := r.currentBatch; batch != nil {
+               r.currentBatch = nil
+               batch.Release()
+       }
+       if stream := r.stream; stream != nil {
+               r.stream = nil
+               stream.Release()
+       }
+}
+
+// Next runs the query (via doQuery) with the next row of bound parameters
+// and returns its result. Errors from NextParams, including io.EOF once
+// every row has been used, are passed through unchanged.
+func (r *snowflakeBindReader) Next() (array.RecordReader, error) {
+       params, err := r.NextParams()
+       if err == nil {
+               return r.doQuery(params)
+       }
+       return nil, err
+}
+
+// NextParams converts the next row of bound data into bind parameters.
+// Exhausted batches are released and, when a stream was bound, replaced by
+// the stream's next batch (empty batches are skipped). It returns io.EOF
+// once no rows remain, or the stream's error if reading it failed.
+func (r *snowflakeBindReader) NextParams() ([]driver.NamedValue, error) {
+       for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() {
+               // A directly bound batch must be released by us, but records
+               // from a stream belong to the stream. To treat both the same,
+               // we always release here and Retain each stream record below.
+               if r.currentBatch != nil {
+                       r.currentBatch.Release()
+                       r.currentBatch = nil
+               }
+               if r.stream == nil {
+                       // only a single batch was bound, and it is used up
+                       return nil, io.EOF
+               }
+               if !r.stream.Next() {
+                       if err := r.stream.Err(); err != nil {
+                               return nil, err
+                       }
+                       return nil, io.EOF
+               }
+               r.currentBatch = r.stream.Record()
+               r.currentBatch.Retain()
+               r.nextIndex = 0
+       }
+
+       params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex))
+       r.nextIndex++
+       return params, err
+}
diff --git a/go/adbc/driver/snowflake/concat_reader.go 
b/go/adbc/driver/snowflake/concat_reader.go
new file mode 100644
index 000000000..c04a81a21
--- /dev/null
+++ b/go/adbc/driver/snowflake/concat_reader.go
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+       "io"
+       "sync/atomic"
+
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow/go/v17/arrow"
+       "github.com/apache/arrow/go/v17/arrow/array"
+)
+
+// readerIter yields, one at a time, the RecordReaders that concatReader
+// chains together.
+type readerIter interface {
+       // Next returns the next reader, or io.EOF when there are no more.
+       Next() (array.RecordReader, error)
+       // Release frees whatever the iterator still holds.
+       Release()
+}
+
+// concatReader is an array.RecordReader that presents the records of a
+// sequence of RecordReaders, obtained from a readerIter, as one stream.
+// Init must be called before any other method.
+type concatReader struct {
+       // reference count for Retain/Release; set to 1 by Init
+       refCount atomic.Int64
+       // source of the readers being concatenated; released with this reader
+       readers readerIter
+       // reader being drained; nil when no reader is active
+       currentReader array.RecordReader
+       // schema of the first reader, returned by Schema
+       schema *arrow.Schema
+       // error that stopped iteration, returned by Err
+       err error
+}
+
+// nextReader releases the current reader and advances to the next one from
+// r.readers. On io.EOF, currentReader is left nil, which marks the end of
+// the stream; any other error is stored in r.err.
+//
+// Once the schema is known (i.e. after Init), a reader whose field names or
+// types differ from it is rejected with an error. Every record produced by
+// a RecordReader must match its Schema(), and nothing here guarantees that
+// the results for different parameter sets share one. Metadata and
+// nullability are not compared.
+func (r *concatReader) nextReader() {
+       if r.currentReader != nil {
+               r.currentReader.Release()
+               r.currentReader = nil
+       }
+       reader, err := r.readers.Next()
+       if err == io.EOF {
+               return
+       }
+       if err != nil {
+               r.err = err
+               return
+       }
+
+       // May be nil
+       if reader != nil && r.schema != nil && !sameFieldTypes(r.schema, reader.Schema()) {
+               msg := "[Snowflake] Result schema changed between parameter sets: expected " +
+                       r.schema.String() + ", got " + reader.Schema().String()
+               reader.Release()
+               r.err = adbc.Error{
+                       Code: adbc.StatusInvalidData,
+                       Msg:  msg,
+               }
+               return
+       }
+       r.currentReader = reader
+}
+
+// sameFieldTypes reports whether a and b have the same number of fields with
+// matching names and data types, ignoring nullability and metadata.
+func sameFieldTypes(a, b *arrow.Schema) bool {
+       if a.NumFields() != b.NumFields() {
+               return false
+       }
+       for i := 0; i < a.NumFields(); i++ {
+               fa, fb := a.Field(i), b.Field(i)
+               if fa.Name != fb.Name || !arrow.TypeEqual(fa.Type, fb.Type) {
+                       return false
+               }
+       }
+       return true
+}
+// Init takes ownership of readers (it is released together with this
+// reader), sets the reference count to 1 and eagerly fetches the first
+// reader to learn the schema. If that fails, or readers yields nothing at
+// all, the reader is released and the error is returned.
+func (r *concatReader) Init(readers readerIter) error {
+       r.readers = readers
+       r.refCount.Store(1)
+
+       r.nextReader()
+       switch {
+       case r.err != nil:
+               r.Release()
+               return r.err
+       case r.currentReader == nil:
+               r.Release()
+               r.err = adbc.Error{
+                       Code: adbc.StatusInternal,
+                       Msg:  "[Snowflake] No data in this stream",
+               }
+               return r.err
+       }
+
+       r.schema = r.currentReader.Schema()
+       return nil
+}
+// Retain adds a reference to the reader.
+func (r *concatReader) Retain() {
+       r.refCount.Add(1)
+}
+
+// Release removes a reference; when the last one is dropped, the active
+// reader and the underlying readerIter are released.
+func (r *concatReader) Release() {
+       if r.refCount.Add(-1) != 0 {
+               return
+       }
+       if r.currentReader != nil {
+               r.currentReader.Release()
+       }
+       r.readers.Release()
+}
+
+// Schema returns the schema of the concatenated stream, i.e. that of the
+// first reader. It panics if Init has not completed successfully.
+func (r *concatReader) Schema() *arrow.Schema {
+       if s := r.schema; s != nil {
+               return s
+       }
+       panic("did not call concatReader.Init")
+}
+// Next advances to the next record, moving on to the following reader
+// whenever the current one is exhausted. It returns false at the end of the
+// stream or on error; use Err to distinguish the two.
+func (r *concatReader) Next() bool {
+       if r.err != nil {
+               return false
+       }
+       for r.currentReader != nil && !r.currentReader.Next() {
+               // A reader may stop because it failed; surface that error
+               // instead of silently continuing with the next parameter set.
+               if err := r.currentReader.Err(); err != nil {
+                       r.err = err
+                       return false
+               }
+               r.nextReader()
+       }
+       return r.currentReader != nil && r.err == nil
+}
+
+// Record returns the current record; call it only after Next returned true.
+func (r *concatReader) Record() arrow.Record {
+       return r.currentReader.Record()
+}
+
+// Err returns the error, if any, that stopped iteration.
+func (r *concatReader) Err() error {
+       return r.err
+}
diff --git a/go/adbc/driver/snowflake/driver.go 
b/go/adbc/driver/snowflake/driver.go
index da49a6097..a49dd13b8 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -19,6 +19,7 @@ package snowflake
 
 import (
        "errors"
+       "maps"
        "runtime/debug"
        "strings"
 
@@ -26,7 +27,6 @@ import (
        "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
        "github.com/apache/arrow/go/v17/arrow/memory"
        "github.com/snowflakedb/gosnowflake"
-       "golang.org/x/exp/maps"
 )
 
 const (
diff --git a/go/adbc/driver/snowflake/statement.go 
b/go/adbc/driver/snowflake/statement.go
index f61db8f06..27bdb2fde 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -19,7 +19,9 @@ package snowflake
 
 import (
        "context"
+       "database/sql/driver"
        "fmt"
+       "io"
        "strconv"
        "strings"
 
@@ -463,10 +465,26 @@ func (st *statement) ExecuteQuery(ctx context.Context) 
(array.RecordReader, int6
        // concatenate RecordReaders which doesn't exist yet. let's put
        // that off for now.
        if st.streamBind != nil || st.bound != nil {
-               return nil, -1, adbc.Error{
-                       Msg:  "executing non-bulk ingest with bound params not 
yet implemented",
-                       Code: adbc.StatusNotImplemented,
+               bind := snowflakeBindReader{
+                       doQuery: func(params []driver.NamedValue) 
(array.RecordReader, error) {
+                               loader, err := st.cnxn.cn.QueryArrowStream(ctx, 
st.query, params...)
+                               if err != nil {
+                                       return nil, 
errToAdbcErr(adbc.StatusInternal, err)
+                               }
+                               return newRecordReader(ctx, st.alloc, loader, 
st.queueSize, st.prefetchConcurrency, st.useHighPrecision)
+                       },
+                       currentBatch: st.bound,
+                       stream:       st.streamBind,
+               }
+               st.bound = nil
+               st.streamBind = nil
+
+               rdr := concatReader{}
+               err := rdr.Init(&bind)
+               if err != nil {
+                       return nil, -1, err
                }
+               return &rdr, -1, nil
        }
 
        loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)
@@ -493,6 +511,38 @@ func (st *statement) ExecuteUpdate(ctx context.Context) 
(int64, error) {
                }
        }
 
+       if st.streamBind != nil || st.bound != nil {
+               numRows := int64(0)
+               bind := snowflakeBindReader{
+                       currentBatch: st.bound,
+                       stream:       st.streamBind,
+               }
+               st.bound = nil
+               st.streamBind = nil
+
+               defer bind.Release()
+               for {
+                       params, err := bind.NextParams()
+                       if err == io.EOF {
+                               break
+                       } else if err != nil {
+                               return -1, err
+                       }
+
+                       r, err := st.cnxn.cn.ExecContext(ctx, st.query, params)
+                       if err != nil {
+                               return -1, errToAdbcErr(adbc.StatusInternal, 
err)
+                       }
+                       n, err := r.RowsAffected()
+                       if err != nil {
+                               numRows = -1
+                       } else if numRows >= 0 {
+                               numRows += n
+                       }
+               }
+               return numRows, nil
+       }
+
        r, err := st.cnxn.cn.ExecContext(ctx, st.query, nil)
        if err != nil {
                return -1, errToAdbcErr(adbc.StatusIO, err)
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index 05dd4871a..ea76bcef5 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
 go 1.21
 
 require (
-       github.com/apache/arrow/go/v17 v17.0.0-20240430043840-e4f31462dbd6
+       github.com/apache/arrow/go/v17 v17.0.0-20240503231747-7cd9c6fbd313
        github.com/bluele/gcache v0.0.2
        github.com/golang/protobuf v1.5.4
        github.com/google/uuid v1.6.0
@@ -31,7 +31,7 @@ require (
        golang.org/x/sync v0.7.0
        golang.org/x/tools v0.21.0
        google.golang.org/grpc v1.63.2
-       google.golang.org/protobuf v1.33.0
+       google.golang.org/protobuf v1.34.0
 )
 
 require (
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index 7bcebdece..f1498113f 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -20,8 +20,8 @@ github.com/andybalholm/brotli v1.1.0 
h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1
 github.com/andybalholm/brotli v1.1.0/go.mod 
h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
 github.com/apache/arrow/go/v15 v15.0.0 
h1:1zZACWf85oEZY5/kd9dsQS7i+2G5zVQcbKTHgslqHNA=
 github.com/apache/arrow/go/v15 v15.0.0/go.mod 
h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA=
-github.com/apache/arrow/go/v17 v17.0.0-20240430043840-e4f31462dbd6 
h1:FjtQoGK5idTYKI48GE49b+M5bsEzUK08rDYrzxRo1aw=
-github.com/apache/arrow/go/v17 v17.0.0-20240430043840-e4f31462dbd6/go.mod 
h1:vihJLOeRHNQmdlAwR/1hvENgHkQJhL74WiRAX9QVmU8=
+github.com/apache/arrow/go/v17 v17.0.0-20240503231747-7cd9c6fbd313 
h1:wnD2WBKoiH6iuEuhg33RsaslZ6aqfrviadRza3bNJZ4=
+github.com/apache/arrow/go/v17 v17.0.0-20240503231747-7cd9c6fbd313/go.mod 
h1:jeCSgGamSUiG483VAAaKkPn5wa/dTCVrSmCzF6PUlEo=
 github.com/apache/thrift v0.20.0 
h1:631+KvYbsBZxmuJjYwhezVsrfc/TbqtZV4QcxOX1fOI=
 github.com/apache/thrift v0.20.0/go.mod 
h1:hOk1BQqcp2OLzGsyVXdfMk7YFlMxK3aoEVhjD06QhB8=
 github.com/aws/aws-sdk-go-v2 v1.25.1 
h1:P7hU6A5qEdmajGwvae/zDkOq+ULLC9tQBTwqqiwFGpI=
@@ -194,8 +194,8 @@ google.golang.org/genproto/googleapis/rpc 
v0.0.0-20240227224415-6ceb2ff114de h1:
 google.golang.org/genproto/googleapis/rpc 
v0.0.0-20240227224415-6ceb2ff114de/go.mod 
h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY=
 google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
 google.golang.org/grpc v1.63.2/go.mod 
h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
-google.golang.org/protobuf v1.33.0 
h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
-google.golang.org/protobuf v1.33.0/go.mod 
h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
+google.golang.org/protobuf v1.34.0 
h1:Qo/qEd2RZPCf2nKuorzksSknv0d3ERwp1vFG38gSmH4=
+google.golang.org/protobuf v1.34.0/go.mod 
h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c 
h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

Reply via email to