This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 8bd0e9bd3 fix(go/adbc/driver/snowflake): handle quotes properly (#1738)
8bd0e9bd3 is described below
commit 8bd0e9bd308a4a1d8222d4c8317e98fd39e8d5ee
Author: Matt Topol <[email protected]>
AuthorDate: Thu Apr 25 16:12:05 2024 -0400
fix(go/adbc/driver/snowflake): handle quotes properly (#1738)
fixes #1721
---
go/adbc/driver/snowflake/bulk_ingestion.go | 24 +++++++++++++++---------
go/adbc/driver/snowflake/connection.go | 6 +++---
go/adbc/driver/snowflake/driver.go | 5 +++++
go/adbc/driver/snowflake/driver_test.go | 12 ++++++++----
go/adbc/driver/snowflake/statement.go | 6 +++---
5 files changed, 34 insertions(+), 19 deletions(-)
diff --git a/go/adbc/driver/snowflake/bulk_ingestion.go b/go/adbc/driver/snowflake/bulk_ingestion.go
index 5e1f1314f..9ec64f6e5 100644
--- a/go/adbc/driver/snowflake/bulk_ingestion.go
+++ b/go/adbc/driver/snowflake/bulk_ingestion.go
@@ -29,7 +29,6 @@ import (
"io"
"math"
"runtime"
- "strconv"
"strings"
"sync"
@@ -130,10 +129,13 @@ func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error)
st.bound = nil
}()
- var initialRows int64
+ var (
+ initialRows int64
+ target = quoteTblName(st.targetTable)
+ )
// Check final row count of target table to get definitive rows affected
- initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
+ initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
if err != nil {
st.bound.Release()
return
@@ -182,13 +184,13 @@ func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error)
}
// Load the uploaded file into the target table
- _, err = st.cnxn.cn.ExecContext(ctx, copyQuery, []driver.NamedValue{{Value: strconv.Quote(st.targetTable)}})
+ _, err = st.cnxn.cn.ExecContext(ctx, copyQuery, []driver.NamedValue{{Value: target}})
if err != nil {
return
}
// Check final row count of target table to get definitive rows affected
- nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
+ nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
nrows = nrows - initialRows
return
}
@@ -204,9 +206,13 @@ func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error)
st.streamBind = nil
}()
- var initialRows int64
+ var (
+ initialRows int64
+ target = quoteTblName(st.targetTable)
+ )
+
// Check final row count of target table to get definitive rows affected
- initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
+ initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
if err != nil {
return
}
@@ -214,7 +220,7 @@ func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error)
defer func() {
// Always check the resulting row count, even in the case of an error. We may have ingested part of the data.
ctx := context.Background() // TODO(joellubi): switch to context.WithoutCancel(ctx) once we're on Go 1.21
- n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
+ n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, target)
nrows = n - initialRows
// Ingestion, row-count check, or both could have failed
@@ -268,7 +274,7 @@ func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error)
}
// Kickoff background tasks to COPY Parquet files into Snowflake table as they are uploaded
- fileReady, finishCopy, cancelCopy := runCopyTasks(ctx, st.cnxn.cn, strconv.Quote(st.targetTable), int(st.ingestOptions.copyConcurrency))
+ fileReady, finishCopy, cancelCopy := runCopyTasks(ctx, st.cnxn.cn, target, int(st.ingestOptions.copyConcurrency))
// Read Parquet files from buffer pool and upload to Snowflake stage in parallel
g.Go(func() error {
diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index 41a8c1665..94223bb92 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -1212,12 +1212,12 @@ func (c *connectionImpl) getStringQuery(query string) (string, error) {
func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) {
tblParts := make([]string, 0, 3)
if catalog != nil {
- tblParts = append(tblParts, strconv.Quote(*catalog))
+ tblParts = append(tblParts, quoteTblName(*catalog))
}
if dbSchema != nil {
- tblParts = append(tblParts, strconv.Quote(*dbSchema))
+ tblParts = append(tblParts, quoteTblName(*dbSchema))
}
- tblParts = append(tblParts, strconv.Quote(tableName))
+ tblParts = append(tblParts, quoteTblName(tableName))
fullyQualifiedTable := strings.Join(tblParts, ".")
rows, err := c.sqldb.QueryContext(ctx, `DESC TABLE `+fullyQualifiedTable)
diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go
index 124c4d388..cde873c88 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -20,6 +20,7 @@ package snowflake
import (
"errors"
"runtime/debug"
+ "strings"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
@@ -163,6 +164,10 @@ func errToAdbcErr(code adbc.Status, err error) error {
}
}
+func quoteTblName(name string) string {
+ return "\"" + strings.ReplaceAll(name, "\"", "\"\"") + "\""
+}
+
type driverImpl struct {
driverbase.DriverImplBase
}
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index aa5280437..968ca942d 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -140,10 +140,14 @@ func getArr(arr arrow.Array) interface{} {
}
}
+func quoteTblName(name string) string {
+ return "\"" + strings.ReplaceAll(name, "\"", "\"\"") + "\""
+}
+
func (s *SnowflakeQuirks) CreateSampleTable(tableName string, r arrow.Record) error {
var b strings.Builder
b.WriteString("CREATE OR REPLACE TABLE ")
- b.WriteString(strconv.Quote(tableName))
+ b.WriteString(quoteTblName(tableName))
b.WriteString(" (")
for i := 0; i < int(r.NumCols()); i++ {
@@ -164,7 +168,7 @@ func (s *SnowflakeQuirks) CreateSampleTable(tableName string, r arrow.Record) er
return err
}
- insertQuery := "INSERT INTO " + strconv.Quote(tableName) + " VALUES ("
+ insertQuery := "INSERT INTO " + quoteTblName(tableName) + " VALUES ("
bindings := strings.Repeat("?,", int(r.NumCols()))
insertQuery += bindings[:len(bindings)-1] + ")"
@@ -184,7 +188,7 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error
}
defer stmt.Close()
- if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + strconv.Quote(tblname)); err != nil {
+ if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + quoteTblName(tblname)); err != nil {
return err
}
@@ -511,7 +515,7 @@ func (suite *SnowflakeTests) TestSqlIngestRecordAndStreamAreEquivalent() {
suite.NoError(err)
suite.Require().NoError(suite.stmt.BindStream(suite.ctx, stream))
- suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_bind_stream"))
+ suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_bind_stream\""))
n, err = suite.stmt.ExecuteUpdate(suite.ctx)
suite.Require().NoError(err)
suite.EqualValues(3, n)
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index 3f446662e..dfa6b8e99 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -361,7 +361,7 @@ func (st *statement) initIngest(ctx context.Context) error {
if st.ingestMode == adbc.OptionValueIngestModeCreateAppend {
createBldr.WriteString(" IF NOT EXISTS ")
}
- createBldr.WriteString(strconv.Quote(st.targetTable))
+ createBldr.WriteString(quoteTblName(st.targetTable))
createBldr.WriteString(" (")
var schema *arrow.Schema
@@ -376,7 +376,7 @@ func (st *statement) initIngest(ctx context.Context) error {
createBldr.WriteString(", ")
}
- createBldr.WriteString(strconv.Quote(f.Name))
+ createBldr.WriteString(quoteTblName(f.Name))
createBldr.WriteString(" ")
ty := toSnowflakeType(f.Type)
if ty == "" {
@@ -398,7 +398,7 @@ func (st *statement) initIngest(ctx context.Context) error {
case adbc.OptionValueIngestModeAppend:
// Do nothing
case adbc.OptionValueIngestModeReplace:
- replaceQuery := "DROP TABLE IF EXISTS " + strconv.Quote(st.targetTable)
+ replaceQuery := "DROP TABLE IF EXISTS " + quoteTblName(st.targetTable)
_, err := st.cnxn.cn.ExecContext(ctx, replaceQuery, nil)
if err != nil {
return errToAdbcErr(adbc.StatusInternal, err)