This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch fixup-metadata-getobjects-snowflake
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git

commit 1ad0c098bfa8017e35a4dc62b1e9ef941857f109
Author: Matt Topol <[email protected]>
AuthorDate: Wed Oct 16 13:04:20 2024 -0400

    reduce duplication
---
 go/adbc/driver/snowflake/connection.go | 132 ++++++++++++++-------------------
 1 file changed, 55 insertions(+), 77 deletions(-)

diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index 90ae67f4a..c77b45436 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -114,28 +114,47 @@ func getQueryID(ctx context.Context, query string, driverConn any) (string, erro
 const (
        objSchemas   = "SCHEMAS"
        objDatabases = "DATABASES"
+       objViews     = "VIEWS"
+       objTables    = "TABLES"
+       objObjects   = "OBJECTS"
 )
 
-func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema *string, outQueryID *string) {
+func addLike(query string, pattern *string) string {
+       if pattern != nil && len(*pattern) > 0 && *pattern != "%" && *pattern != ".*" {
+               query += " LIKE '" + escapeSingleQuoteForLike(*pattern) + "'"
+       }
+       return query
+}
+
+func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema, tableName *string, outQueryID *string) {
        grp.Go(func() error {
                return conn.Raw(func(driverConn any) (err error) {
                        query := "SHOW TERSE /* ADBC:getObjects */ " + objType
                        switch objType {
                        case objDatabases:
-                               if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" {
-                                       query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'"
-                               }
+                               query = addLike(query, catalog)
                                query += " IN ACCOUNT"
                        case objSchemas:
-                               if dbSchema != nil && len(*dbSchema) > 0 && *dbSchema != "%" && *dbSchema != ".*" {
-                                       query += " LIKE '" + escapeSingleQuoteForLike(*dbSchema) + "'"
-                               }
+                               query = addLike(query, dbSchema)
 
                                if catalog == nil || isWildcardStr(*catalog) {
                                        query += " IN ACCOUNT"
                                } else {
                                        query += " IN DATABASE " + quoteTblName(*catalog)
                                }
+                       case objViews, objTables, objObjects:
+                               query = addLike(query, tableName)
+
+                               if catalog == nil || isWildcardStr(*catalog) {
+                                       query += " IN ACCOUNT"
+                               } else {
+                                       escapedCatalog := quoteTblName(*catalog)
+                                       if dbSchema == nil || isWildcardStr(*dbSchema) {
+                                               query += " IN DATABASE " + escapedCatalog
+                                       } else {
+                                               query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema)
+                                       }
+                               }
                        default:
                                return fmt.Errorf("unimplemented object type")
                        }
@@ -150,7 +169,7 @@ func isWildcardStr(ident string) bool {
        return strings.ContainsAny(ident, "_%")
 }
 
-func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) {
+func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog, dbSchema, tableName, columnName *string, tableType []string) (array.RecordReader, error) {
        var (
                pkQueryID, fkQueryID, uniqueQueryID, terseDbQueryID string
                showSchemaQueryID, tableQueryID                     string
@@ -181,50 +200,29 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
        case adbc.ObjectDepthCatalogs:
                queryFile = queryTemplateGetObjectsTerseCatalogs
                goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
-                       catalog, dbSchema, &terseDbQueryID)
+                       catalog, dbSchema, tableName, &terseDbQueryID)
        case adbc.ObjectDepthDBSchemas:
                queryFile = queryTemplateGetObjectsDbSchemas
                goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
-                       catalog, dbSchema, &showSchemaQueryID)
+                       catalog, dbSchema, tableName, &showSchemaQueryID)
                goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
-                       catalog, dbSchema, &terseDbQueryID)
+                       catalog, dbSchema, tableName, &terseDbQueryID)
        case adbc.ObjectDepthTables:
                queryFile = queryTemplateGetObjectsTables
                goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
-                       catalog, dbSchema, &showSchemaQueryID)
+                       catalog, dbSchema, tableName, &showSchemaQueryID)
                goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
-                       catalog, dbSchema, &terseDbQueryID)
-               gQueryIDs.Go(func() error {
-                       return conn.Raw(func(driverConn any) (err error) {
-                               objType := "objects"
-                               if len(tableType) == 1 {
-                                       if strings.EqualFold("VIEW", tableType[0]) {
-                                               objType = "views"
-                                       } else if strings.EqualFold("TABLE", tableType[0]) {
-                                               objType = "tables"
-                                       }
-                               }
-
-                               query := "SHOW TERSE /* ADBC:getObjectsTables */ " + objType
-                               if tableName != nil && len(*tableName) > 0 && *tableName != "%" && *tableName != ".*" {
-                                       query += " LIKE '" + escapeSingleQuoteForLike(*tableName) + "'"
-                               }
-                               if catalog == nil || isWildcardStr(*catalog) {
-                                       query += " IN ACCOUNT"
-                               } else {
-                                       escapedCatalog := quoteTblName(*catalog)
-                                       if dbSchema == nil || isWildcardStr(*dbSchema) {
-                                               query += " IN DATABASE " + escapedCatalog
-                                       } else {
-                                               query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema)
-                                       }
-                               }
-
-                               tableQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn)
-                               return
-                       })
-               })
-               // fallthrough
+                       catalog, dbSchema, tableName, &terseDbQueryID)
+               objType := objObjects
+               if len(tableType) == 1 {
+                       if strings.EqualFold("VIEW", tableType[0]) {
+                               objType = objViews
+                       } else if strings.EqualFold("TABLE", tableType[0]) {
+                               objType = objTables
+                       }
+               }
+               goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType,
+                       catalog, dbSchema, tableName, &tableQueryID)
        default:
                var suffix string
                if catalog == nil || isWildcardStr(*catalog) {
@@ -268,40 +266,20 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
                })
 
                goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
-                       catalog, dbSchema, &terseDbQueryID)
+                       catalog, dbSchema, tableName, &terseDbQueryID)
                goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
-                       catalog, dbSchema, &showSchemaQueryID)
-
-               gQueryIDs.Go(func() error {
-                       return conn.Raw(func(driverConn any) (err error) {
-                               objType := "objects"
-                               if len(tableType) == 1 {
-                                       if strings.EqualFold("VIEW", tableType[0]) {
-                                               objType = "views"
-                                       } else if strings.EqualFold("TABLE", tableType[0]) {
-                                               objType = "tables"
-                                       }
-                               }
-
-                               query := "SHOW TERSE /* ADBC:getObjectsTables */ " + objType
-                               if tableName != nil && len(*tableName) > 0 && *tableName != "%" && *tableName != ".*" {
-                                       query += " LIKE '" + escapeSingleQuoteForLike(*tableName) + "'"
-                               }
-                               if catalog == nil || isWildcardStr(*catalog) {
-                                       query += " IN ACCOUNT"
-                               } else {
-                                       escapedCatalog := quoteTblName(*catalog)
-                                       if dbSchema == nil || isWildcardStr(*dbSchema) {
-                                               query += " IN DATABASE " + escapedCatalog
-                                       } else {
-                                               query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema)
-                                       }
-                               }
-
-                               tableQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn)
-                               return
-                       })
-               })
+                       catalog, dbSchema, tableName, &showSchemaQueryID)
+
+               objType := objObjects
+               if len(tableType) == 1 {
+                       if strings.EqualFold("VIEW", tableType[0]) {
+                               objType = objViews
+                       } else if strings.EqualFold("TABLE", tableType[0]) {
+                               objType = objTables
+                       }
+               }
+               goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType,
+                       catalog, dbSchema, tableName, &tableQueryID)
        }
 
        queryBytes, err := fs.ReadFile(queryTemplates, path.Join("queries", queryFile))

Reply via email to