This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 8bd0e9bd3 fix(go/adbc/driver/snowflake): handle quotes properly (#1738)
8bd0e9bd3 is described below

commit 8bd0e9bd308a4a1d8222d4c8317e98fd39e8d5ee
Author: Matt Topol <[email protected]>
AuthorDate: Thu Apr 25 16:12:05 2024 -0400

    fix(go/adbc/driver/snowflake): handle quotes properly (#1738)
    
    fixes #1721
---
 go/adbc/driver/snowflake/bulk_ingestion.go | 24 +++++++++++++++---------
 go/adbc/driver/snowflake/connection.go     |  6 +++---
 go/adbc/driver/snowflake/driver.go         |  5 +++++
 go/adbc/driver/snowflake/driver_test.go    | 12 ++++++++----
 go/adbc/driver/snowflake/statement.go      |  6 +++---
 5 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/go/adbc/driver/snowflake/bulk_ingestion.go b/go/adbc/driver/snowflake/bulk_ingestion.go
index 5e1f1314f..9ec64f6e5 100644
--- a/go/adbc/driver/snowflake/bulk_ingestion.go
+++ b/go/adbc/driver/snowflake/bulk_ingestion.go
@@ -29,7 +29,6 @@ import (
        "io"
        "math"
        "runtime"
-       "strconv"
        "strings"
        "sync"
 
@@ -130,10 +129,13 @@ func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error)
                st.bound = nil
        }()
 
-       var initialRows int64
+       var (
+               initialRows int64
+               target      = quoteTblName(st.targetTable)
+       )
 
        // Check final row count of target table to get definitive rows affected
-       initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
+       initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
        if err != nil {
                st.bound.Release()
                return
@@ -182,13 +184,13 @@ func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error)
        }
 
        // Load the uploaded file into the target table
-       _, err = st.cnxn.cn.ExecContext(ctx, copyQuery, []driver.NamedValue{{Value: strconv.Quote(st.targetTable)}})
+       _, err = st.cnxn.cn.ExecContext(ctx, copyQuery, []driver.NamedValue{{Value: target}})
        if err != nil {
                return
        }
 
        // Check final row count of target table to get definitive rows affected
-       nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
+       nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
        nrows = nrows - initialRows
        return
 }
@@ -204,9 +206,13 @@ func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error)
                st.streamBind = nil
        }()
 
-       var initialRows int64
+       var (
+               initialRows int64
+               target      = quoteTblName(st.targetTable)
+       )
+
        // Check final row count of target table to get definitive rows affected
-       initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
+       initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
        if err != nil {
                return
        }
@@ -214,7 +220,7 @@ func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error)
        defer func() {
                // Always check the resulting row count, even in the case of an error. We may have ingested part of the data.
                ctx := context.Background() // TODO(joellubi): switch to context.WithoutCancel(ctx) once we're on Go 1.21
-               n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
+               n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, target)
                nrows = n - initialRows
 
                // Ingestion, row-count check, or both could have failed
@@ -268,7 +274,7 @@ func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error)
        }
 
        // Kickoff background tasks to COPY Parquet files into Snowflake table as they are uploaded
-       fileReady, finishCopy, cancelCopy := runCopyTasks(ctx, st.cnxn.cn, strconv.Quote(st.targetTable), int(st.ingestOptions.copyConcurrency))
+       fileReady, finishCopy, cancelCopy := runCopyTasks(ctx, st.cnxn.cn, target, int(st.ingestOptions.copyConcurrency))
 
        // Read Parquet files from buffer pool and upload to Snowflake stage in parallel
        g.Go(func() error {
diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index 41a8c1665..94223bb92 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -1212,12 +1212,12 @@ func (c *connectionImpl) getStringQuery(query string) (string, error) {
 func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) {
        tblParts := make([]string, 0, 3)
        if catalog != nil {
-               tblParts = append(tblParts, strconv.Quote(*catalog))
+               tblParts = append(tblParts, quoteTblName(*catalog))
        }
        if dbSchema != nil {
-               tblParts = append(tblParts, strconv.Quote(*dbSchema))
+               tblParts = append(tblParts, quoteTblName(*dbSchema))
        }
-       tblParts = append(tblParts, strconv.Quote(tableName))
+       tblParts = append(tblParts, quoteTblName(tableName))
        fullyQualifiedTable := strings.Join(tblParts, ".")
 
        rows, err := c.sqldb.QueryContext(ctx, `DESC TABLE `+fullyQualifiedTable)
diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go
index 124c4d388..cde873c88 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -20,6 +20,7 @@ package snowflake
 import (
        "errors"
        "runtime/debug"
+       "strings"
 
        "github.com/apache/arrow-adbc/go/adbc"
        "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
@@ -163,6 +164,10 @@ func errToAdbcErr(code adbc.Status, err error) error {
        }
 }
 
+func quoteTblName(name string) string {
+       return "\"" + strings.ReplaceAll(name, "\"", "\"\"") + "\""
+}
+
 type driverImpl struct {
        driverbase.DriverImplBase
 }
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index aa5280437..968ca942d 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -140,10 +140,14 @@ func getArr(arr arrow.Array) interface{} {
        }
 }
 
+func quoteTblName(name string) string {
+       return "\"" + strings.ReplaceAll(name, "\"", "\"\"") + "\""
+}
+
 func (s *SnowflakeQuirks) CreateSampleTable(tableName string, r arrow.Record) error {
        var b strings.Builder
        b.WriteString("CREATE OR REPLACE TABLE ")
-       b.WriteString(strconv.Quote(tableName))
+       b.WriteString(quoteTblName(tableName))
        b.WriteString(" (")
 
        for i := 0; i < int(r.NumCols()); i++ {
@@ -164,7 +168,7 @@ func (s *SnowflakeQuirks) CreateSampleTable(tableName string, r arrow.Record) er
                return err
        }
 
-       insertQuery := "INSERT INTO " + strconv.Quote(tableName) + " VALUES ("
+       insertQuery := "INSERT INTO " + quoteTblName(tableName) + " VALUES ("
        bindings := strings.Repeat("?,", int(r.NumCols()))
        insertQuery += bindings[:len(bindings)-1] + ")"
 
@@ -184,7 +188,7 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error
        }
        defer stmt.Close()
 
-       if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + strconv.Quote(tblname)); err != nil {
+       if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + quoteTblName(tblname)); err != nil {
                return err
        }
 
@@ -511,7 +515,7 @@ func (suite *SnowflakeTests) TestSqlIngestRecordAndStreamAreEquivalent() {
        suite.NoError(err)
 
        suite.Require().NoError(suite.stmt.BindStream(suite.ctx, stream))
-       suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_bind_stream"))
+       suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_bind_stream\""))
        n, err = suite.stmt.ExecuteUpdate(suite.ctx)
        suite.Require().NoError(err)
        suite.EqualValues(3, n)
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index 3f446662e..dfa6b8e99 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -361,7 +361,7 @@ func (st *statement) initIngest(ctx context.Context) error {
        if st.ingestMode == adbc.OptionValueIngestModeCreateAppend {
                createBldr.WriteString(" IF NOT EXISTS ")
        }
-       createBldr.WriteString(strconv.Quote(st.targetTable))
+       createBldr.WriteString(quoteTblName(st.targetTable))
        createBldr.WriteString(" (")
 
        var schema *arrow.Schema
@@ -376,7 +376,7 @@ func (st *statement) initIngest(ctx context.Context) error {
                        createBldr.WriteString(", ")
                }
 
-               createBldr.WriteString(strconv.Quote(f.Name))
+               createBldr.WriteString(quoteTblName(f.Name))
                createBldr.WriteString(" ")
                ty := toSnowflakeType(f.Type)
                if ty == "" {
@@ -398,7 +398,7 @@ func (st *statement) initIngest(ctx context.Context) error {
        case adbc.OptionValueIngestModeAppend:
                // Do nothing
        case adbc.OptionValueIngestModeReplace:
-               replaceQuery := "DROP TABLE IF EXISTS " + strconv.Quote(st.targetTable)
+               replaceQuery := "DROP TABLE IF EXISTS " + quoteTblName(st.targetTable)
                _, err := st.cnxn.cn.ExecContext(ctx, replaceQuery, nil)
                if err != nil {
                        return errToAdbcErr(adbc.StatusInternal, err)

Reply via email to