This is an automated email from the ASF dual-hosted git repository. zeroshade pushed a commit to branch fixup-metadata-getobjects-snowflake in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
commit 1ad0c098bfa8017e35a4dc62b1e9ef941857f109 Author: Matt Topol <[email protected]> AuthorDate: Wed Oct 16 13:04:20 2024 -0400 reduce duplication --- go/adbc/driver/snowflake/connection.go | 132 ++++++++++++++------------------- 1 file changed, 55 insertions(+), 77 deletions(-) diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 90ae67f4a..c77b45436 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -114,28 +114,47 @@ func getQueryID(ctx context.Context, query string, driverConn any) (string, erro const ( objSchemas = "SCHEMAS" objDatabases = "DATABASES" + objViews = "VIEWS" + objTables = "TABLES" + objObjects = "OBJECTS" ) -func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema *string, outQueryID *string) { +func addLike(query string, pattern *string) string { + if pattern != nil && len(*pattern) > 0 && *pattern != "%" && *pattern != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*pattern) + "'" + } + return query +} + +func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema, tableName *string, outQueryID *string) { grp.Go(func() error { return conn.Raw(func(driverConn any) (err error) { query := "SHOW TERSE /* ADBC:getObjects */ " + objType switch objType { case objDatabases: - if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" - } + query = addLike(query, catalog) query += " IN ACCOUNT" case objSchemas: - if dbSchema != nil && len(*dbSchema) > 0 && *dbSchema != "%" && *dbSchema != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*dbSchema) + "'" - } + query = addLike(query, dbSchema) if catalog == nil || isWildcardStr(*catalog) { query += " IN ACCOUNT" } else { query += " IN DATABASE " + quoteTblName(*catalog) } + case objViews, objTables, objObjects: + query = addLike(query, tableName) + + if catalog == nil || isWildcardStr(*catalog) { + query += " IN ACCOUNT" + } else { + escapedCatalog := quoteTblName(*catalog) + if dbSchema == nil || isWildcardStr(*dbSchema) { + query += " IN DATABASE " + escapedCatalog + } else { + query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema) + } + } default: return fmt.Errorf("unimplemented object type") } @@ -150,7 +169,7 @@ func isWildcardStr(ident string) bool { return strings.ContainsAny(ident, "_%") } -func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { +func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog, dbSchema, tableName, columnName *string, tableType []string) (array.RecordReader, error) { var ( pkQueryID, fkQueryID, uniqueQueryID, terseDbQueryID string showSchemaQueryID, tableQueryID string @@ -181,50 +200,29 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, case adbc.ObjectDepthCatalogs: queryFile = queryTemplateGetObjectsTerseCatalogs goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, - catalog, dbSchema, &terseDbQueryID) + catalog, dbSchema, tableName, &terseDbQueryID) case adbc.ObjectDepthDBSchemas: queryFile = queryTemplateGetObjectsDbSchemas goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, - catalog, dbSchema, &showSchemaQueryID) + catalog, dbSchema, tableName, &showSchemaQueryID) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, - catalog, dbSchema, &terseDbQueryID) + catalog, dbSchema, tableName, &terseDbQueryID) case adbc.ObjectDepthTables: queryFile = queryTemplateGetObjectsTables goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, - catalog, dbSchema, &showSchemaQueryID) + catalog, dbSchema, tableName, &showSchemaQueryID) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, - catalog, dbSchema, &terseDbQueryID) - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) (err error) { - objType := "objects" - if len(tableType) == 1 { - if strings.EqualFold("VIEW", tableType[0]) { - objType = "views" - } else if strings.EqualFold("TABLE", tableType[0]) { - objType = "tables" - } - } - - query := "SHOW TERSE /* ADBC:getObjectsTables */ " + objType - if tableName != nil && len(*tableName) > 0 && *tableName != "%" && *tableName != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*tableName) + "'" - } - if catalog == nil || isWildcardStr(*catalog) { - query += " IN ACCOUNT" - } else { - escapedCatalog := quoteTblName(*catalog) - if dbSchema == nil || isWildcardStr(*dbSchema) { - query += " IN DATABASE " + escapedCatalog - } else { - query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema) - } - } - - tableQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) - return - }) - }) - // fallthrough + catalog, dbSchema, tableName, &terseDbQueryID) + objType := "objects" + if len(tableType) == 1 { + if strings.EqualFold("VIEW", tableType[0]) { + objType = "views" + } else if strings.EqualFold("TABLE", tableType[0]) { + objType = "tables" + } + } + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType, + catalog, dbSchema, tableName, &tableQueryID) default: var suffix string if catalog == nil || isWildcardStr(*catalog) { @@ -268,40 +266,20 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, }) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, - catalog, dbSchema, &terseDbQueryID) + catalog, dbSchema, tableName, &terseDbQueryID) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, - catalog, dbSchema, &showSchemaQueryID) - - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) (err error) { - objType := "objects" - if len(tableType) == 1 { - if strings.EqualFold("VIEW", tableType[0]) { - objType = "views" - } else if strings.EqualFold("TABLE", tableType[0]) { - objType = "tables" - } - } - - query := "SHOW TERSE /* ADBC:getObjectsTables */ " + objType - if tableName != nil && len(*tableName) > 0 && *tableName != "%" && *tableName != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*tableName) + "'" - } - if catalog == nil || isWildcardStr(*catalog) { - query += " IN ACCOUNT" - } else { - escapedCatalog := quoteTblName(*catalog) - if dbSchema == nil || isWildcardStr(*dbSchema) { - query += " IN DATABASE " + escapedCatalog - } else { - query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema) - } - } - - tableQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) - return - }) - }) + catalog, dbSchema, tableName, &showSchemaQueryID) + + objType := "objects" + if len(tableType) == 1 { + if strings.EqualFold("VIEW", tableType[0]) { + objType = "views" + } else if strings.EqualFold("TABLE", tableType[0]) { + objType = "tables" + } + } + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType, + catalog, dbSchema, tableName, &tableQueryID) } queryBytes, err := fs.ReadFile(queryTemplates, path.Join("queries", queryFile))
