This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 7706ace06 fix(go/adbc/driver/snowflake): use one session for connection (#2494)
7706ace06 is described below

commit 7706ace06a4da083693cd329877504b9b4d9220f
Author: Matt Topol <[email protected]>
AuthorDate: Tue Feb 18 12:24:49 2025 -0500

    fix(go/adbc/driver/snowflake): use one session for connection (#2494)
    
    Instead of creating a separate connection for simple queries like SELECT
    COUNT and the metadata queries, we'll keep using the single connection
    to maintain the existing session throughout. This should avoid the
    problem mentioned in #2128 where a new session cancels the remaining
    COPY INTO queries. It also simplifies some of the work on metadata
    queries so that we don't need to explicitly propagate catalog, schema
    and database names between connections.
    
    Fixes #2517
---
 go/adbc/driver/snowflake/bulk_ingestion.go     |  42 ++++--
 go/adbc/driver/snowflake/connection.go         | 179 ++++++++++++-------------
 go/adbc/driver/snowflake/snowflake_database.go |   2 -
 3 files changed, 114 insertions(+), 109 deletions(-)

diff --git a/go/adbc/driver/snowflake/bulk_ingestion.go 
b/go/adbc/driver/snowflake/bulk_ingestion.go
index 4884d039b..cdce01461 100644
--- a/go/adbc/driver/snowflake/bulk_ingestion.go
+++ b/go/adbc/driver/snowflake/bulk_ingestion.go
@@ -22,7 +22,6 @@ import (
        "bytes"
        "compress/flate"
        "context"
-       "database/sql"
        "database/sql/driver"
        "errors"
        "fmt"
@@ -31,6 +30,7 @@ import (
        "path"
        "runtime"
        "slices"
+       "strconv"
        "strings"
        "sync"
        "time"
@@ -140,7 +140,7 @@ func (st *statement) ingestRecord(ctx context.Context) 
(nrows int64, err error)
        )
 
        // Check final row count of target table to get definitive rows affected
-       initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
+       initialRows, err = countRowsInTable(ctx, st.cnxn.cn, target)
        if err != nil {
                st.bound.Release()
                return
@@ -195,7 +195,7 @@ func (st *statement) ingestRecord(ctx context.Context) 
(nrows int64, err error)
        }
 
        // Check final row count of target table to get definitive rows affected
-       nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
+       nrows, err = countRowsInTable(ctx, st.cnxn.cn, target)
        nrows = nrows - initialRows
        return
 }
@@ -217,7 +217,7 @@ func (st *statement) ingestStream(ctx context.Context) 
(nrows int64, err error)
        )
 
        // Check final row count of target table to get definitive rows affected
-       initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, target)
+       initialRows, err = countRowsInTable(ctx, st.cnxn.cn, target)
        if err != nil {
                return
        }
@@ -225,7 +225,7 @@ func (st *statement) ingestStream(ctx context.Context) 
(nrows int64, err error)
        defer func() {
                // Always check the resulting row count, even in the case of an 
error. We may have ingested part of the data.
                ctx := context.WithoutCancel(ctx)
-               n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, target)
+               n, countErr := countRowsInTable(ctx, st.cnxn.cn, target)
                nrows = n - initialRows
 
                // Ingestion, row-count check, or both could have failed
@@ -321,7 +321,7 @@ func readRecords(ctx context.Context, rdr 
array.RecordReader, out chan<- arrow.R
                }
        }
 
-       return nil
+       return rdr.Err()
 }
 
 func writeRecordToParquet(wr *pqarrow.FileWriter, rec arrow.Record) (int64, 
error) {
@@ -536,6 +536,7 @@ func runCopyTasks(ctx context.Context, cn snowflakeConn, 
tableName string, concu
        g, ctx := errgroup.WithContext(ctx)
        g.SetLimit(concurrency)
 
+       done := make(chan struct{})
        readyCh := make(chan struct{}, 1)
        stopCh := make(chan interface{})
 
@@ -570,6 +571,17 @@ func runCopyTasks(ctx context.Context, cn snowflakeConn, 
tableName string, concu
                close(stopCh)
                close(readyCh)
 
+               <-done
+
+               // wait for any currently running copies to finish before we 
continue
+               if err := g.Wait(); err != nil {
+                       return err
+               }
+
+               if filesToCopy.Len() == 0 {
+                       return nil
+               }
+
                maxRetries := 5 // maybe make configurable?
                for attempt := 0; attempt < maxRetries+1; attempt++ {
                        if attempt > 0 {
@@ -585,7 +597,7 @@ func runCopyTasks(ctx context.Context, cn snowflakeConn, 
tableName string, concu
 
                        if filesToCopy.Len() == 0 {
                                // all files successfully copied
-                               return g.Wait()
+                               return nil
                        }
                }
 
@@ -601,6 +613,7 @@ func runCopyTasks(ctx context.Context, cn snowflakeConn, 
tableName string, concu
        }
 
        go func() {
+               defer close(done)
                for {
 
                        // Block until there is at least 1 new file available 
for copy, or it's time to shutdown
@@ -625,15 +638,20 @@ func runCopyTasks(ctx context.Context, cn snowflakeConn, 
tableName string, concu
        return readyFn, stopFn, cancelFn
 }
 
-func countRowsInTable(ctx context.Context, db *sql.DB, tableName string) 
(int64, error) {
-       var nrows int64
+func countRowsInTable(ctx context.Context, db snowflakeConn, tableName string) 
(int64, error) {
+       rows, err := db.QueryContext(ctx, countQuery, 
[]driver.NamedValue{{Value: tableName}})
+       if err != nil {
+               return 0, errToAdbcErr(adbc.StatusIO, err)
+       }
+       defer rows.Close()
 
-       row := db.QueryRowContext(ctx, countQuery, tableName)
-       if err := row.Scan(&nrows); err != nil {
+       dest := make([]driver.Value, 1)
+       if err := rows.Next(dest); err != nil {
                return 0, errToAdbcErr(adbc.StatusIO, err)
        }
 
-       return nrows, nil
+       n, err := strconv.Atoi(dest[0].(string))
+       return int64(n), err
 }
 
 // Initializes a sync.Pool of *bytes.Buffer.
diff --git a/go/adbc/driver/snowflake/connection.go 
b/go/adbc/driver/snowflake/connection.go
index 2658e6f09..c8dcae424 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -22,6 +22,7 @@ import (
        "database/sql"
        "database/sql/driver"
        "embed"
+       "errors"
        "fmt"
        "io"
        "io/fs"
@@ -66,10 +67,9 @@ type snowflakeConn interface {
 type connectionImpl struct {
        driverbase.ConnectionImplBase
 
-       cn    snowflakeConn
-       db    *databaseImpl
-       ctor  driver.Connector
-       sqldb *sql.DB
+       cn   snowflakeConn
+       db   *databaseImpl
+       ctor driver.Connector
 
        activeTransaction bool
        useHighPrecision  bool
@@ -103,8 +103,8 @@ func escapeSingleQuoteForLike(arg string) string {
        }
 }
 
-func getQueryID(ctx context.Context, query string, driverConn any) (string, 
error) {
-       rows, err := driverConn.(driver.QueryerContext).QueryContext(ctx, 
query, nil)
+func getQueryID(ctx context.Context, query string, driverConn 
driver.QueryerContext) (string, error) {
+       rows, err := driverConn.QueryContext(ctx, query, nil)
        if err != nil {
                return "", err
        }
@@ -127,42 +127,41 @@ func addLike(query string, pattern *string) string {
        return query
 }
 
-func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, 
objType string, catalog, dbSchema, tableName *string, outQueryID *string) {
+func goGetQueryID(ctx context.Context, conn driver.QueryerContext, grp 
*errgroup.Group, objType string, catalog, dbSchema, tableName *string, 
outQueryID *string) {
        grp.Go(func() error {
-               return conn.Raw(func(driverConn any) (err error) {
-                       query := "SHOW TERSE /* ADBC:getObjects */ " + objType
-                       switch objType {
-                       case objDatabases:
-                               query = addLike(query, catalog)
+               query := "SHOW TERSE /* ADBC:getObjects */ " + objType
+               switch objType {
+               case objDatabases:
+                       query = addLike(query, catalog)
+                       query += " IN ACCOUNT"
+               case objSchemas:
+                       query = addLike(query, dbSchema)
+
+                       if catalog == nil || isWildcardStr(*catalog) {
                                query += " IN ACCOUNT"
-                       case objSchemas:
-                               query = addLike(query, dbSchema)
-
-                               if catalog == nil || isWildcardStr(*catalog) {
-                                       query += " IN ACCOUNT"
-                               } else {
-                                       query += " IN DATABASE " + 
quoteTblName(*catalog)
-                               }
-                       case objViews, objTables, objObjects:
-                               query = addLike(query, tableName)
+                       } else {
+                               query += " IN DATABASE " + 
quoteTblName(*catalog)
+                       }
+               case objViews, objTables, objObjects:
+                       query = addLike(query, tableName)
 
-                               if catalog == nil || isWildcardStr(*catalog) {
-                                       query += " IN ACCOUNT"
+                       if catalog == nil || isWildcardStr(*catalog) {
+                               query += " IN ACCOUNT"
+                       } else {
+                               escapedCatalog := quoteTblName(*catalog)
+                               if dbSchema == nil || isWildcardStr(*dbSchema) {
+                                       query += " IN DATABASE " + 
escapedCatalog
                                } else {
-                                       escapedCatalog := quoteTblName(*catalog)
-                                       if dbSchema == nil || 
isWildcardStr(*dbSchema) {
-                                               query += " IN DATABASE " + 
escapedCatalog
-                                       } else {
-                                               query += " IN SCHEMA " + 
escapedCatalog + "." + quoteTblName(*dbSchema)
-                                       }
+                                       query += " IN SCHEMA " + escapedCatalog 
+ "." + quoteTblName(*dbSchema)
                                }
-                       default:
-                               return fmt.Errorf("unimplemented object type")
                        }
+               default:
+                       return fmt.Errorf("unimplemented object type")
+               }
 
-                       *outQueryID, err = getQueryID(ctx, query, driverConn)
-                       return
-               })
+               var err error
+               *outQueryID, err = getQueryID(ctx, query, conn)
+               return err
        })
 }
 
@@ -176,12 +175,7 @@ func (c *connectionImpl) GetObjects(ctx context.Context, 
depth adbc.ObjectDepth,
                showSchemaQueryID, tableQueryID                     string
        )
 
-       conn, err := c.sqldb.Conn(ctx)
-       if err != nil {
-               return nil, err
-       }
-       defer conn.Close()
-
+       conn := c.cn
        var hasViews, hasTables bool
        for _, t := range tableType {
                if strings.EqualFold("VIEW", t) {
@@ -253,25 +247,19 @@ func (c *connectionImpl) GetObjects(ctx context.Context, 
depth adbc.ObjectDepth,
 
                // Detailed constraint info not available in information_schema
                // Need to dispatch SHOW queries and use conn.Raw to extract 
the queryID for reuse in GetObjects query
-               gQueryIDs.Go(func() error {
-                       return conn.Raw(func(driverConn any) (err error) {
-                               pkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW 
PRIMARY KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
-                               return err
-                       })
+               gQueryIDs.Go(func() (err error) {
+                       pkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW PRIMARY 
KEYS /* ADBC:getObjectsTables */"+suffix, conn)
+                       return err
                })
 
-               gQueryIDs.Go(func() error {
-                       return conn.Raw(func(driverConn any) (err error) {
-                               fkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW 
IMPORTED KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
-                               return err
-                       })
+               gQueryIDs.Go(func() (err error) {
+                       fkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW 
IMPORTED KEYS /* ADBC:getObjectsTables */"+suffix, conn)
+                       return err
                })
 
-               gQueryIDs.Go(func() error {
-                       return conn.Raw(func(driverConn any) (err error) {
-                               uniqueQueryID, err = getQueryID(gQueryIDsCtx, 
"SHOW UNIQUE KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
-                               return err
-                       })
+               gQueryIDs.Go(func() (err error) {
+                       uniqueQueryID, err = getQueryID(gQueryIDsCtx, "SHOW 
UNIQUE KEYS /* ADBC:getObjectsTables */"+suffix, conn)
+                       return err
                })
 
                goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
@@ -301,7 +289,7 @@ func (c *connectionImpl) GetObjects(ctx context.Context, 
depth adbc.ObjectDepth,
                return nil, err
        }
 
-       args := []any{
+       args := []sql.NamedArg{
                // Optional filter patterns
                driverbase.PatternToNamedArg("CATALOG", catalog),
                driverbase.PatternToNamedArg("DB_SCHEMA", dbSchema),
@@ -318,32 +306,17 @@ func (c *connectionImpl) GetObjects(ctx context.Context, 
depth adbc.ObjectDepth,
                sql.Named("SHOW_TABLE_QUERY_ID", tableQueryID),
        }
 
-       // currently only the Columns / all case still requires a current 
database/schema
-       // to be propagated. The rest of the cases all solely use SHOW queries 
for the metadata
-       // just as done by the snowflake JDBC driver. In those cases we don't 
need to propagate
-       // the current session database/schema.
-       if depth == adbc.ObjectDepthColumns || depth == adbc.ObjectDepthAll {
-               dbname, err := c.GetCurrentCatalog()
-               if err != nil {
-                       return nil, errToAdbcErr(adbc.StatusIO, err)
-               }
-
-               schemaname, err := c.GetCurrentDbSchema()
-               if err != nil {
-                       return nil, errToAdbcErr(adbc.StatusIO, err)
-               }
-
-               // the connection that is used is not the same connection 
context where the database may have been set
-               // if the caller called SetCurrentCatalog() so need to ensure 
the database context is appropriate
-               multiCtx, _ := gosnowflake.WithMultiStatement(ctx, 2)
-               _, err = conn.ExecContext(multiCtx, fmt.Sprintf("USE DATABASE 
%s; USE SCHEMA %s;", quoteTblName(dbname), quoteTblName(schemaname)))
-               if err != nil {
-                       return nil, errToAdbcErr(adbc.StatusIO, err)
+       nvargs := make([]driver.NamedValue, len(args))
+       for i, arg := range args {
+               nvargs[i] = driver.NamedValue{
+                       Name:    arg.Name,
+                       Ordinal: i + 1,
+                       Value:   arg.Value,
                }
        }
 
        query := string(queryBytes)
-       rows, err := conn.QueryContext(ctx, query, args...)
+       rows, err := conn.QueryContext(ctx, query, nvargs)
        if err != nil {
                return nil, errToAdbcErr(adbc.StatusIO, err)
        }
@@ -354,9 +327,18 @@ func (c *connectionImpl) GetObjects(ctx context.Context, 
depth adbc.ObjectDepth,
 
        go func() {
                defer close(catalogCh)
-               for rows.Next() {
+               dest := make([]driver.Value, len(rows.Columns()))
+               for {
+                       if err := rows.Next(dest); err != nil {
+                               if errors.Is(err, io.EOF) {
+                                       return
+                               }
+                               errCh <- errToAdbcErr(adbc.StatusInvalidData, 
err)
+                               return
+                       }
+
                        var getObjectsCatalog driverbase.GetObjectsInfo
-                       if err := rows.Scan(&getObjectsCatalog); err != nil {
+                       if err := getObjectsCatalog.Scan(dest[0]); err != nil {
                                errCh <- errToAdbcErr(adbc.StatusInvalidData, 
err)
                                return
                        }
@@ -630,22 +612,34 @@ func (c *connectionImpl) GetTableSchema(ctx 
context.Context, catalog *string, db
        tblParts = append(tblParts, quoteTblName(tableName))
        fullyQualifiedTable := strings.Join(tblParts, ".")
 
-       rows, err := c.sqldb.QueryContext(ctx, `DESC TABLE 
`+fullyQualifiedTable)
+       rows, err := c.cn.QueryContext(ctx, `DESC TABLE `+fullyQualifiedTable, 
nil)
        if err != nil {
                return nil, errToAdbcErr(adbc.StatusIO, err)
        }
        defer rows.Close()
 
        var (
-               name, typ, kind, isnull, primary, unique          string
-               def, check, expr, comment, policyName, privDomain sql.NullString
-               fields                                            = 
[]arrow.Field{}
+               name, typ, isnull, primary string
+               comment                    sql.NullString
+               fields                     = []arrow.Field{}
        )
 
-       for rows.Next() {
-               err := rows.Scan(&name, &typ, &kind, &isnull, &def, &primary, 
&unique,
-                       &check, &expr, &comment, &policyName, &privDomain)
-               if err != nil {
+       // columns are:
+       // name, type, kind, isnull, primary, unique, def, check, expr, 
comment, policyName, privDomain
+       dest := make([]driver.Value, len(rows.Columns()))
+       for {
+               if err := rows.Next(dest); err != nil {
+                       if errors.Is(err, io.EOF) {
+                               break
+                       }
+                       return nil, errToAdbcErr(adbc.StatusIO, err)
+               }
+
+               name = dest[0].(string)
+               typ = dest[1].(string)
+               isnull = dest[3].(string)
+               primary = dest[5].(string)
+               if err := comment.Scan(dest[9]); err != nil {
                        return nil, errToAdbcErr(adbc.StatusIO, err)
                }
 
@@ -703,15 +697,10 @@ func (c *connectionImpl) NewStatement() (adbc.Statement, 
error) {
 
 // Close closes this connection and releases any associated resources.
 func (c *connectionImpl) Close() error {
-       if c.sqldb == nil || c.cn == nil {
+       if c.cn == nil {
                return adbc.Error{Code: adbc.StatusInvalidState}
        }
 
-       if err := c.sqldb.Close(); err != nil {
-               return err
-       }
-       c.sqldb = nil
-
        defer func() {
                c.cn = nil
        }()
diff --git a/go/adbc/driver/snowflake/snowflake_database.go 
b/go/adbc/driver/snowflake/snowflake_database.go
index 26f6b627c..a0548558f 100644
--- a/go/adbc/driver/snowflake/snowflake_database.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -21,7 +21,6 @@ import (
        "context"
        "crypto/rsa"
        "crypto/x509"
-       "database/sql"
        "encoding/pem"
        "errors"
        "fmt"
@@ -452,7 +451,6 @@ func (d *databaseImpl) Open(ctx context.Context) 
(adbc.Connection, error) {
        conn := &connectionImpl{
                cn: cn.(snowflakeConn),
                db: d, ctor: connector,
-               sqldb: sql.OpenDB(connector),
                // default enable high precision
                // SetOption(OptionUseHighPrecision, adbc.OptionValueDisabled) 
to
                // get Int64/Float64 instead

Reply via email to