This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 8f0bdb367 feat(go/adbc/driver/snowflake): support parameter binding (#1808)
8f0bdb367 is described below
commit 8f0bdb367708f36183e7aa3691ff14c17c28d4e2
Author: David Li <[email protected]>
AuthorDate: Tue May 7 23:11:25 2024 +0900
feat(go/adbc/driver/snowflake): support parameter binding (#1808)
Fixes #1144.
---
c/driver/snowflake/snowflake_test.cc | 2 +-
go/adbc/driver/snowflake/binding.go | 153 ++++++++++++++++++++++++++++++
go/adbc/driver/snowflake/concat_reader.go | 107 +++++++++++++++++++++
go/adbc/driver/snowflake/driver.go | 2 +-
go/adbc/driver/snowflake/statement.go | 56 ++++++++++-
go/adbc/go.mod | 4 +-
go/adbc/go.sum | 8 +-
7 files changed, 321 insertions(+), 11 deletions(-)
diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc
index 0fe07ecbd..a4d742491 100644
--- a/c/driver/snowflake/snowflake_test.cc
+++ b/c/driver/snowflake/snowflake_test.cc
@@ -146,7 +146,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
bool supports_metadata_current_catalog() const override { return false; }
bool supports_metadata_current_db_schema() const override { return false; }
bool supports_partitioned_data() const override { return false; }
- bool supports_dynamic_parameter_binding() const override { return false; }
+ bool supports_dynamic_parameter_binding() const override { return true; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
std::string db_schema() const override { return schema_; }
diff --git a/go/adbc/driver/snowflake/binding.go b/go/adbc/driver/snowflake/binding.go
new file mode 100644
index 000000000..7f6878945
--- /dev/null
+++ b/go/adbc/driver/snowflake/binding.go
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+ "database/sql"
+ "database/sql/driver"
+ "fmt"
+ "io"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow/go/v17/arrow"
+ "github.com/apache/arrow/go/v17/arrow/array"
+)
+
+func convertArrowToNamedValue(batch arrow.Record, index int) ([]driver.NamedValue, error) {
+ // see goTypeToSnowflake in gosnowflake
+ // technically, snowflake can bind an array of values at once, but
+ // only for INSERT, so we can't take advantage of that without
+ // analyzing the query ourselves
+ params := make([]driver.NamedValue, batch.NumCols())
+ for i, field := range batch.Schema().Fields() {
+ rawColumn := batch.Column(i)
+ params[i].Ordinal = i + 1
+ switch column := rawColumn.(type) {
+ case *array.Boolean:
+ params[i].Value = sql.NullBool{
+ Bool: column.Value(index),
+ Valid: column.IsValid(index),
+ }
+ case *array.Float32:
+ // Snowflake only recognizes float64
+ params[i].Value = sql.NullFloat64{
+ Float64: float64(column.Value(index)),
+ Valid: column.IsValid(index),
+ }
+ case *array.Float64:
+ params[i].Value = sql.NullFloat64{
+ Float64: column.Value(index),
+ Valid: column.IsValid(index),
+ }
+ case *array.Int8:
+ // Snowflake only recognizes int64
+ params[i].Value = sql.NullInt64{
+ Int64: int64(column.Value(index)),
+ Valid: column.IsValid(index),
+ }
+ case *array.Int16:
+ params[i].Value = sql.NullInt64{
+ Int64: int64(column.Value(index)),
+ Valid: column.IsValid(index),
+ }
+ case *array.Int32:
+ params[i].Value = sql.NullInt64{
+ Int64: int64(column.Value(index)),
+ Valid: column.IsValid(index),
+ }
+ case *array.Int64:
+ params[i].Value = sql.NullInt64{
+ Int64: column.Value(index),
+ Valid: column.IsValid(index),
+ }
+ case *array.String:
+ params[i].Value = sql.NullString{
+ String: column.Value(index),
+ Valid: column.IsValid(index),
+ }
+ case *array.LargeString:
+ params[i].Value = sql.NullString{
+ String: column.Value(index),
+ Valid: column.IsValid(index),
+ }
+ default:
+ return nil, adbc.Error{
+ Code: adbc.StatusNotImplemented,
+ Msg: fmt.Sprintf("[Snowflake] Unsupported bind param '%s' type %s", field.Name, field.Type.String()),
+ }
+ }
+ }
+ return params, nil
+}
+
+type snowflakeBindReader struct {
+ doQuery func([]driver.NamedValue) (array.RecordReader, error)
+ currentBatch arrow.Record
+ nextIndex int64
+ // may be nil if we bound only a batch
+ stream array.RecordReader
+}
+
+func (r *snowflakeBindReader) Release() {
+ if r.currentBatch != nil {
+ r.currentBatch.Release()
+ r.currentBatch = nil
+ }
+ if r.stream != nil {
+ r.stream.Release()
+ r.stream = nil
+ }
+}
+
+func (r *snowflakeBindReader) Next() (array.RecordReader, error) {
+ params, err := r.NextParams()
+ if err != nil {
+ // includes EOF
+ return nil, err
+ }
+ return r.doQuery(params)
+}
+
+func (r *snowflakeBindReader) NextParams() ([]driver.NamedValue, error) {
+ for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() {
+ // We can be used both by binding a stream or by binding a
+ // batch. In the latter case, we have to release the batch,
+ // but not in the former case. Unify the cases by always
+ // releasing the batch, adding an "extra" retain so that the
+ // release does not cause issues.
+ if r.currentBatch != nil {
+ r.currentBatch.Release()
+ }
+ r.currentBatch = nil
+ if r.stream != nil && r.stream.Next() {
+ r.currentBatch = r.stream.Record()
+ r.currentBatch.Retain()
+ r.nextIndex = 0
+ continue
+ } else if r.stream != nil && r.stream.Err() != nil {
+ return nil, r.stream.Err()
+ } else {
+ // no more params
+ return nil, io.EOF
+ }
+ }
+
+ params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex))
+ r.nextIndex++
+ return params, err
+}
diff --git a/go/adbc/driver/snowflake/concat_reader.go b/go/adbc/driver/snowflake/concat_reader.go
new file mode 100644
index 000000000..c04a81a21
--- /dev/null
+++ b/go/adbc/driver/snowflake/concat_reader.go
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+ "io"
+ "sync/atomic"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow/go/v17/arrow"
+ "github.com/apache/arrow/go/v17/arrow/array"
+)
+
+type readerIter interface {
+ Release()
+
+ Next() (array.RecordReader, error)
+}
+
+type concatReader struct {
+ refCount atomic.Int64
+ readers readerIter
+ currentReader array.RecordReader
+ schema *arrow.Schema
+ err error
+}
+
+func (r *concatReader) nextReader() {
+ if r.currentReader != nil {
+ r.currentReader.Release()
+ r.currentReader = nil
+ }
+ reader, err := r.readers.Next()
+ if err == io.EOF {
+ r.currentReader = nil
+ } else if err != nil {
+ r.err = err
+ } else {
+ // May be nil
+ r.currentReader = reader
+ }
+}
+func (r *concatReader) Init(readers readerIter) error {
+ r.readers = readers
+ r.refCount.Store(1)
+ r.nextReader()
+ if r.err != nil {
+ r.Release()
+ return r.err
+ } else if r.currentReader == nil {
+ r.Release()
+ r.err = adbc.Error{
+ Code: adbc.StatusInternal,
+ Msg: "[Snowflake] No data in this stream",
+ }
+ return r.err
+ }
+ r.schema = r.currentReader.Schema()
+ return nil
+}
+func (r *concatReader) Retain() {
+ r.refCount.Add(1)
+}
+func (r *concatReader) Release() {
+ if r.refCount.Add(-1) == 0 {
+ if r.currentReader != nil {
+ r.currentReader.Release()
+ }
+ r.readers.Release()
+ }
+}
+func (r *concatReader) Schema() *arrow.Schema {
+ if r.schema == nil {
+ panic("did not call concatReader.Init")
+ }
+ return r.schema
+}
+func (r *concatReader) Next() bool {
+ for r.currentReader != nil && !r.currentReader.Next() {
+ r.nextReader()
+ }
+ if r.currentReader == nil || r.err != nil {
+ return false
+ }
+ return true
+}
+func (r *concatReader) Record() arrow.Record {
+ return r.currentReader.Record()
+}
+func (r *concatReader) Err() error {
+ return r.err
+}
diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go
index da49a6097..a49dd13b8 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -19,6 +19,7 @@ package snowflake
import (
"errors"
+ "maps"
"runtime/debug"
"strings"
@@ -26,7 +27,6 @@ import (
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/snowflakedb/gosnowflake"
- "golang.org/x/exp/maps"
)
const (
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index f61db8f06..27bdb2fde 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -19,7 +19,9 @@ package snowflake
import (
"context"
+ "database/sql/driver"
"fmt"
+ "io"
"strconv"
"strings"
@@ -463,10 +465,26 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6
// concatenate RecordReaders which doesn't exist yet. let's put
// that off for now.
if st.streamBind != nil || st.bound != nil {
- return nil, -1, adbc.Error{
- Msg: "executing non-bulk ingest with bound params not yet implemented",
- Code: adbc.StatusNotImplemented,
+ bind := snowflakeBindReader{
+ doQuery: func(params []driver.NamedValue) (array.RecordReader, error) {
+ loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query, params...)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusInternal, err)
+ }
+ return newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision)
+ },
+ currentBatch: st.bound,
+ stream: st.streamBind,
+ }
+ st.bound = nil
+ st.streamBind = nil
+
+ rdr := concatReader{}
+ err := rdr.Init(&bind)
+ if err != nil {
+ return nil, -1, err
}
+ return &rdr, -1, nil
}
loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)
@@ -493,6 +511,38 @@ func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
}
}
+ if st.streamBind != nil || st.bound != nil {
+ numRows := int64(0)
+ bind := snowflakeBindReader{
+ currentBatch: st.bound,
+ stream: st.streamBind,
+ }
+ st.bound = nil
+ st.streamBind = nil
+
+ defer bind.Release()
+ for {
+ params, err := bind.NextParams()
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return -1, err
+ }
+
+ r, err := st.cnxn.cn.ExecContext(ctx, st.query, params)
+ if err != nil {
+ return -1, errToAdbcErr(adbc.StatusInternal, err)
+ }
+ n, err := r.RowsAffected()
+ if err != nil {
+ numRows = -1
+ } else if numRows >= 0 {
+ numRows += n
+ }
+ }
+ return numRows, nil
+ }
+
r, err := st.cnxn.cn.ExecContext(ctx, st.query, nil)
if err != nil {
return -1, errToAdbcErr(adbc.StatusIO, err)
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index 05dd4871a..ea76bcef5 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
go 1.21
require (
- github.com/apache/arrow/go/v17 v17.0.0-20240430043840-e4f31462dbd6
+ github.com/apache/arrow/go/v17 v17.0.0-20240503231747-7cd9c6fbd313
github.com/bluele/gcache v0.0.2
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
@@ -31,7 +31,7 @@ require (
golang.org/x/sync v0.7.0
golang.org/x/tools v0.21.0
google.golang.org/grpc v1.63.2
- google.golang.org/protobuf v1.33.0
+ google.golang.org/protobuf v1.34.0
)
require (
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index 7bcebdece..f1498113f 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -20,8 +20,8 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/apache/arrow/go/v15 v15.0.0 h1:1zZACWf85oEZY5/kd9dsQS7i+2G5zVQcbKTHgslqHNA=
github.com/apache/arrow/go/v15 v15.0.0/go.mod h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA=
-github.com/apache/arrow/go/v17 v17.0.0-20240430043840-e4f31462dbd6 h1:FjtQoGK5idTYKI48GE49b+M5bsEzUK08rDYrzxRo1aw=
-github.com/apache/arrow/go/v17 v17.0.0-20240430043840-e4f31462dbd6/go.mod h1:vihJLOeRHNQmdlAwR/1hvENgHkQJhL74WiRAX9QVmU8=
+github.com/apache/arrow/go/v17 v17.0.0-20240503231747-7cd9c6fbd313 h1:wnD2WBKoiH6iuEuhg33RsaslZ6aqfrviadRza3bNJZ4=
+github.com/apache/arrow/go/v17 v17.0.0-20240503231747-7cd9c6fbd313/go.mod h1:jeCSgGamSUiG483VAAaKkPn5wa/dTCVrSmCzF6PUlEo=
github.com/apache/thrift v0.20.0
h1:631+KvYbsBZxmuJjYwhezVsrfc/TbqtZV4QcxOX1fOI=
github.com/apache/thrift v0.20.0/go.mod
h1:hOk1BQqcp2OLzGsyVXdfMk7YFlMxK3aoEVhjD06QhB8=
github.com/aws/aws-sdk-go-v2 v1.25.1
h1:P7hU6A5qEdmajGwvae/zDkOq+ULLC9tQBTwqqiwFGpI=
@@ -194,8 +194,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY=
google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
-google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
-google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
+google.golang.org/protobuf v1.34.0 h1:Qo/qEd2RZPCf2nKuorzksSknv0d3ERwp1vFG38gSmH4=
+google.golang.org/protobuf v1.34.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=