This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 64c19ebf fix(go/adbc/driver/snowflake): Removing SQL injection to get
table name with special character for getObjectsTables (#1338)
64c19ebf is described below
commit 64c19ebf7519dc285aa703412469b3634a45651a
Author: AnithaPanduranganMS
<[email protected]>
AuthorDate: Wed Jan 3 09:22:56 2024 -0800
fix(go/adbc/driver/snowflake): Removing SQL injection to get table name
with special character for getObjectsTables (#1338)
**Description:**
> The GetObjects API was inconsistent when getting tables whose names
contain special characters, and it made the filter conditions case-insensitive.
**Solution:**
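> Pass the catalog, schema, table and column name patterns as bound query arguments (`?` placeholders) instead of concatenating them into the SQL text, to avoid SQL injection.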
**Testing:**
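> Added a test (CanGetObjectsTablesWithSpecialCharacter) to the Snowflake interop DriverTests, plus Go tests in the new connection_test.go.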
Fixes #1225
---------
Co-authored-by: Anitha <[email protected]>
Co-authored-by: David Li <[email protected]>
---
.../test/Drivers/Interop/Snowflake/DriverTests.cs | 105 ++++-
go/adbc/driver/flightsql/flightsql_connection.go | 4 +-
go/adbc/driver/internal/shared_utils.go | 21 +-
go/adbc/driver/snowflake/connection.go | 508 +++++++++++++--------
go/adbc/driver/snowflake/connection_test.go | 320 +++++++++++++
5 files changed, 756 insertions(+), 202 deletions(-)
diff --git a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
index ea3574fc..55535be6 100644
--- a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
@@ -1,4 +1,4 @@
-/*
+/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@@ -16,6 +16,7 @@
*/
using System;
+using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Apache.Arrow.Adbc.Tests.Metadata;
@@ -110,10 +111,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
for (int i = 0; i < queries.Length; i++)
{
string query = queries[i];
- using AdbcStatement statement = _connection.CreateStatement();
- statement.SqlQuery = query;
-
- UpdateResult updateResult = statement.ExecuteUpdate();
+ UpdateResult updateResult = ExecuteUpdateStatement(query);
Assert.Equal(expectedResults[i], updateResult.AffectedRows);
}
@@ -279,17 +277,66 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
{
IEnumerable<AdbcColumn> highPrecisionColumns = columns.Where(c
=> c.XdbcTypeName == "NUMBER");
- if(highPrecisionColumns.Count() > 0)
+ if (highPrecisionColumns.Count() > 0)
{
// ensure they all are coming back as
XdbcDataType_XDBC_DECIMAL because they are Decimal128
short XdbcDataType_XDBC_DECIMAL = 3;
- IEnumerable<AdbcColumn> invalidHighPrecisionColumns =
highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
+ IEnumerable<AdbcColumn> invalidHighPrecisionColumns =
highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
int count = invalidHighPrecisionColumns.Count();
Assert.True(count == 0, $"There are {count} columns that
do not map to the correct XdbcSqlDataType when UseHighPrecision=true");
}
}
}
+ /// <summary>
+ /// Validates if the driver can call GetObjects with GetObjectsDepth
as Tables with TableName as a Special Character.
+ /// </summary>
+ [SkippableTheory, Order(3)]
+ [InlineData(@"ADBCDEMO_DB",@"PUBLIC","MyIdentifier")]
+ [InlineData(@"ADBCDEMO'DB", @"PUBLIC'SCHEMA","my.identifier")]
+ [InlineData(@"ADBCDEM""DB", @"PUBLIC""SCHEMA", "my.identifier")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", "my identifier")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My 'Identifier'")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", "3rd_identifier")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", "$Identifier")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My ^Identifier")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My ^Ident~ifier")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"My\^Ident~ifier")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", "идентификатор")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"ADBCTest_""ALL""TYPES")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"ADBC\TEST""\TAB_""LE")]
+ [InlineData(@"ADBCDEMO_DB", @"PUBLIC", "ONE")]
+ public void CanGetObjectsTablesWithSpecialCharacter(string
databaseName, string schemaName, string tableName)
+ {
+ CreateDatabaseAndTable(databaseName, schemaName, tableName);
+
+ using IArrowArrayStream stream = _connection.GetObjects(
+ depth: AdbcConnection.GetObjectsDepth.Tables,
+ catalogPattern: databaseName,
+ dbSchemaPattern: schemaName,
+ tableNamePattern: tableName,
+ tableTypes: new List<string> { "BASE TABLE", "VIEW" },
+ columnNamePattern: null);
+
+ using RecordBatch recordBatch =
stream.ReadNextRecordBatchAsync().Result;
+
+ List<AdbcCatalog> catalogs =
GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);
+
+ List<AdbcTable> tables = catalogs
+ .Where(c => string.Equals(c.Name, databaseName))
+ .Select(c => c.DbSchemas)
+ .FirstOrDefault()
+ .Where(s => string.Equals(s.Name, schemaName))
+ .Select(s => s.Tables)
+ .FirstOrDefault();
+
+ AdbcTable table = tables.FirstOrDefault();
+
+ Assert.True(table != null, "table should not be null");
+ Assert.Equal(tableName, table.Name, true);
+ DropDatabaseAndTable(databaseName, schemaName, tableName);
+ }
+
/// <summary>
/// Validates if the driver can call GetTableSchema.
/// </summary>
@@ -354,6 +401,50 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
Tests.DriverTests.CanExecuteQuery(queryResult,
_testConfiguration.ExpectedResultsCount);
}
+ private void CreateDatabaseAndTable(string databaseName, string
schemaName, string tableName)
+ {
+ databaseName = databaseName.Replace("\"", "\"\"");
+ schemaName = schemaName.Replace("\"", "\"\"");
+ tableName = tableName.Replace("\"", "\"\"");
+
+ string createDatabase = string.Format("CREATE DATABASE IF NOT
EXISTS \"{0}\"", databaseName);
+ ExecuteUpdateStatement(createDatabase);
+
+ string createSchema = string.Format("CREATE SCHEMA IF NOT EXISTS
\"{0}\".\"{1}\"", databaseName, schemaName);
+ ExecuteUpdateStatement(createSchema);
+
+ string fullyQualifiedTableName =
string.Format("\"{0}\".\"{1}\".\"{2}\"", databaseName, schemaName, tableName);
+ string createTableStatement = string.Format("CREATE OR REPLACE
TABLE {0} (INDEX INT)", fullyQualifiedTableName);
+ ExecuteUpdateStatement(createTableStatement);
+
+ }
+
+ private void DropDatabaseAndTable(string databaseName, string
schemaName, string tableName)
+ {
+ tableName = tableName.Replace("\"", "\"\"");
+ schemaName = schemaName.Replace("\"", "\"\"");
+ databaseName = databaseName.Replace("\"", "\"\"");
+
+ string fullyQualifiedTableName =
string.Format("\"{0}\".\"{1}\".\"{2}\"", databaseName, schemaName, tableName);
+ string createTableStatement = string.Format("DROP TABLE IF EXISTS
{0} ", fullyQualifiedTableName);
+ ExecuteUpdateStatement(createTableStatement);
+
+ string createSchema = string.Format("DROP SCHEMA IF EXISTS
\"{0}\".\"{1}\"", databaseName, schemaName);
+ ExecuteUpdateStatement(createSchema);
+
+ string createDatabase = string.Format("DROP DATABASE IF EXISTS
\"{0}\"", databaseName);
+ ExecuteUpdateStatement(createDatabase);
+
+ }
+
+ private UpdateResult ExecuteUpdateStatement(string query)
+ {
+ using AdbcStatement statement = _connection.CreateStatement();
+ statement.SqlQuery = query;
+ UpdateResult updateResult = statement.ExecuteUpdate();
+ return updateResult;
+ }
+
private static string GetPartialNameForPatternMatch(string name)
{
if (string.IsNullOrEmpty(name) || name.Length == 1) return name;
diff --git a/go/adbc/driver/flightsql/flightsql_connection.go
b/go/adbc/driver/flightsql/flightsql_connection.go
index a541bfeb..ee24b75e 100644
--- a/go/adbc/driver/flightsql/flightsql_connection.go
+++ b/go/adbc/driver/flightsql/flightsql_connection.go
@@ -547,7 +547,7 @@ func (c *cnxn) readInfo(ctx context.Context, expectedSchema
*arrow.Schema, info
}
// Helper function to build up a map of catalogs to DB schemas
-func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string) (result
map[string][]string, err error) {
+func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords
[]internal.Metadata) (result map[string][]string, err error) {
if depth == adbc.ObjectDepthCatalogs {
return
}
@@ -588,7 +588,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context,
depth adbc.ObjectDepth,
return
}
-func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, dbSchema *string, tableName *string, columnName *string,
tableType []string) (result internal.SchemaToTableInfo, err error) {
+func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, dbSchema *string, tableName *string, columnName *string,
tableType []string, metadataRecords []internal.Metadata) (result
internal.SchemaToTableInfo, err error) {
if depth == adbc.ObjectDepthCatalogs || depth ==
adbc.ObjectDepthDBSchemas {
return
}
diff --git a/go/adbc/driver/internal/shared_utils.go
b/go/adbc/driver/internal/shared_utils.go
index a6b285f0..3a579119 100644
--- a/go/adbc/driver/internal/shared_utils.go
+++ b/go/adbc/driver/internal/shared_utils.go
@@ -19,9 +19,11 @@ package internal
import (
"context"
+ "database/sql"
"regexp"
"strconv"
"strings"
+ "time"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v14/arrow"
@@ -38,8 +40,18 @@ type TableInfo struct {
Schema *arrow.Schema
}
-type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, schema *string) (map[string][]string, error)
-type GetObjTablesFn func(ctx context.Context, depth adbc.ObjectDepth, catalog
*string, schema *string, tableName *string, columnName *string, tableType
[]string) (map[CatalogAndSchema][]TableInfo, error)
+type Metadata struct {
+ Created
time.Time
+ ColName, DataType
string
+ Dbname, Kind, Schema, TblName, TblType, IdentGen, IdentIncrement,
Comment sql.NullString
+ OrdinalPos
int
+ NumericPrec, NumericPrecRadix, NumericScale, DatetimePrec
sql.NullInt16
+ IsNullable, IsIdent
bool
+ CharMaxLength, CharOctetLength
sql.NullInt32
+}
+
+type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, schema *string, metadataRecords []Metadata)
(map[string][]string, error)
+type GetObjTablesFn func(ctx context.Context, depth adbc.ObjectDepth, catalog
*string, schema *string, tableName *string, columnName *string, tableType
[]string, metadataRecords []Metadata) (map[CatalogAndSchema][]TableInfo, error)
type SchemaToTableInfo = map[CatalogAndSchema][]TableInfo
// Helper function that compiles a SQL-style pattern (%, _) to a regex
@@ -87,6 +99,7 @@ type GetObjects struct {
builder *array.RecordBuilder
schemaLookup map[string][]string
tableLookup map[CatalogAndSchema][]TableInfo
+ MetadataRecords []Metadata
catalogPattern *regexp.Regexp
columnNamePattern *regexp.Regexp
@@ -123,13 +136,13 @@ type GetObjects struct {
}
func (g *GetObjects) Init(mem memory.Allocator, getObj GetObjDBSchemasFn,
getTbls GetObjTablesFn) error {
- if catalogToDbSchemas, err := getObj(g.Ctx, g.Depth, g.Catalog,
g.DbSchema); err != nil {
+ if catalogToDbSchemas, err := getObj(g.Ctx, g.Depth, g.Catalog,
g.DbSchema, g.MetadataRecords); err != nil {
return err
} else {
g.schemaLookup = catalogToDbSchemas
}
- if tableLookup, err := getTbls(g.Ctx, g.Depth, g.Catalog, g.DbSchema,
g.TableName, g.ColumnName, g.TableType); err != nil {
+ if tableLookup, err := getTbls(g.Ctx, g.Depth, g.Catalog, g.DbSchema,
g.TableName, g.ColumnName, g.TableType, g.MetadataRecords); err != nil {
return err
} else {
g.tableLookup = tableLookup
diff --git a/go/adbc/driver/snowflake/connection.go
b/go/adbc/driver/snowflake/connection.go
index e9679993..73c31604 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -23,6 +23,7 @@ import (
"database/sql/driver"
"fmt"
"io"
+ "regexp"
"strconv"
"strings"
"time"
@@ -238,84 +239,59 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes
[]adbc.InfoCode) (array.Re
// All non-empty, non-nil strings should be a search pattern (as described
// earlier).
func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog
*string, dbSchema *string, tableName *string, columnName *string, tableType
[]string) (array.RecordReader, error) {
- g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog,
DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType:
tableType}
- if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas,
c.getObjectsTables); err != nil {
+ metadataRecords, err := c.populateMetadata(ctx, depth, catalog,
dbSchema, tableName, columnName, tableType)
+ if err != nil {
return nil, err
}
- defer g.Release()
- rows, err := c.sqldb.QueryContext(ctx, "SHOW TERSE DATABASES", nil)
- if err != nil {
+ g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog,
DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType:
tableType}
+ g.MetadataRecords = metadataRecords
+ if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas,
c.getObjectsTables); err != nil {
return nil, err
}
- defer rows.Close()
-
- var (
- created time.Time
- name string
- kind, dbname, schema sql.NullString
- )
- for rows.Next() {
- if err := rows.Scan(&created, &name, &kind, &dbname, &schema);
err != nil {
- return nil, errToAdbcErr(adbc.StatusInvalidData, err)
- }
+ defer g.Release()
- // SNOWFLAKE catalog contains functions and no tables
- if name == "SNOWFLAKE" {
+ uniqueCatalogs := make(map[string]bool)
+ for _, data := range metadataRecords {
+ if !data.Dbname.Valid {
continue
}
- // schema for SHOW TERSE DATABASES is:
- // created_on:timestamp, name:text, kind:null,
database_name:null, schema_name:null
- // the last three columns are always null because they are not
applicable for databases
- // so we want values[1].(string) for the name
- g.AppendCatalog(name)
+ if _, exists := uniqueCatalogs[data.Dbname.String]; !exists {
+ uniqueCatalogs[data.Dbname.String] = true
+ g.AppendCatalog(data.Dbname.String)
+ }
}
return g.Finish()
}
-func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string) (result
map[string][]string, err error) {
+func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords
[]internal.Metadata) (result map[string][]string, err error) {
if depth == adbc.ObjectDepthCatalogs {
return
}
- conditions := make([]string, 0)
- if catalog != nil && *catalog != "" {
- conditions = append(conditions, ` CATALOG_NAME ILIKE
'`+*catalog+`'`)
- }
- if dbSchema != nil && *dbSchema != "" {
- conditions = append(conditions, ` SCHEMA_NAME ILIKE
'`+*dbSchema+`'`)
- }
-
- cond := strings.Join(conditions, " AND ")
-
result = make(map[string][]string)
+ uniqueCatalogSchema := make(map[string]map[string]bool)
- query := `SELECT CATALOG_NAME, SCHEMA_NAME FROM
INFORMATION_SCHEMA.SCHEMATA`
- if cond != "" {
- query += " WHERE " + cond
- }
- var rows *sql.Rows
- rows, err = c.sqldb.QueryContext(ctx, query)
- if err != nil {
- err = errToAdbcErr(adbc.StatusIO, err)
- return
- }
- defer rows.Close()
+ for _, data := range metadataRecords {
+ if !data.Dbname.Valid || !data.Schema.Valid {
+ continue
+ }
- var catalogName, schemaName string
- for rows.Next() {
- if err = rows.Scan(&catalogName, &schemaName); err != nil {
- err = errToAdbcErr(adbc.StatusIO, err)
- return
+ if _, exists := uniqueCatalogSchema[data.Dbname.String];
!exists {
+ uniqueCatalogSchema[data.Dbname.String] =
make(map[string]bool)
}
- cat, ok := result[catalogName]
- if !ok {
+ cat, exists := result[data.Dbname.String]
+ if !exists {
cat = make([]string, 0, 1)
}
- result[catalogName] = append(cat, schemaName)
+
+ if _, exists :=
uniqueCatalogSchema[data.Dbname.String][data.Schema.String]; !exists {
+ result[data.Dbname.String] = append(cat,
data.Schema.String)
+
uniqueCatalogSchema[data.Dbname.String][data.Schema.String] = true
+ }
}
return
@@ -477,7 +453,7 @@ func toXdbcDataType(dt arrow.DataType) (xdbcType
internal.XdbcDataType) {
}
}
-func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, dbSchema *string, tableName *string, columnName *string,
tableType []string) (result internal.SchemaToTableInfo, err error) {
+func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, dbSchema *string, tableName *string, columnName *string,
tableType []string, metadataRecords []internal.Metadata) (result
internal.SchemaToTableInfo, err error) {
if depth == adbc.ObjectDepthCatalogs || depth ==
adbc.ObjectDepthDBSchemas {
return
}
@@ -485,152 +461,45 @@ func (c *cnxn) getObjectsTables(ctx context.Context,
depth adbc.ObjectDepth, cat
result = make(internal.SchemaToTableInfo)
includeSchema := depth == adbc.ObjectDepthAll || depth ==
adbc.ObjectDepthColumns
- conditions := make([]string, 0)
- if catalog != nil && *catalog != "" {
- conditions = append(conditions, ` TABLE_CATALOG ILIKE
'`+*catalog+`'`)
- }
- if dbSchema != nil && *dbSchema != "" {
- conditions = append(conditions, ` TABLE_SCHEMA ILIKE
'`+*dbSchema+`'`)
- }
- if tableName != nil && *tableName != "" {
- conditions = append(conditions, ` TABLE_NAME ILIKE
'`+*tableName+`'`)
- }
-
- // first populate the tables and table types
- var rows *sql.Rows
- var tblConditions []string
- if len(tableType) > 0 {
- tblConditions = append(conditions, ` TABLE_TYPE IN
('`+strings.Join(tableType, `','`)+`')`)
- } else {
- tblConditions = conditions
- }
+ uniqueCatalogSchemaTable := make(map[string]map[string]map[string]bool)
+ for _, data := range metadataRecords {
+ if !data.Dbname.Valid || !data.Schema.Valid ||
!data.TblName.Valid || !data.TblType.Valid {
+ continue
+ }
- cond := strings.Join(tblConditions, " AND ")
- query := "SELECT table_catalog, table_schema, table_name, table_type
FROM INFORMATION_SCHEMA.TABLES"
- if cond != "" {
- query += " WHERE " + cond
- }
- rows, err = c.sqldb.QueryContext(ctx, query)
- if err != nil {
- err = errToAdbcErr(adbc.StatusIO, err)
- return
- }
- defer rows.Close()
+ if _, exists := uniqueCatalogSchemaTable[data.Dbname.String];
!exists {
+ uniqueCatalogSchemaTable[data.Dbname.String] =
make(map[string]map[string]bool)
+ }
- var tblCat, tblSchema, tblName string
- var tblType sql.NullString
- for rows.Next() {
- if err = rows.Scan(&tblCat, &tblSchema, &tblName, &tblType);
err != nil {
- err = errToAdbcErr(adbc.StatusIO, err)
- return
+ if _, exists :=
uniqueCatalogSchemaTable[data.Dbname.String][data.Schema.String]; !exists {
+
uniqueCatalogSchemaTable[data.Dbname.String][data.Schema.String] =
make(map[string]bool)
}
- key := internal.CatalogAndSchema{
- Catalog: tblCat, Schema: tblSchema}
+ if _, exists :=
uniqueCatalogSchemaTable[data.Dbname.String][data.Schema.String][data.TblName.String];
!exists {
+
uniqueCatalogSchemaTable[data.Dbname.String][data.Schema.String][data.TblName.String]
= true
- result[key] = append(result[key], internal.TableInfo{
- Name: tblName, TableType: tblType.String})
- }
+ key := internal.CatalogAndSchema{
+ Catalog: data.Dbname.String, Schema:
data.Schema.String}
- if includeSchema {
- conditions := make([]string, 0)
- if catalog != nil && *catalog != "" {
- conditions = append(conditions, ` TABLE_CATALOG ILIKE
\'`+*catalog+`\'`)
- }
- if dbSchema != nil && *dbSchema != "" {
- conditions = append(conditions, ` TABLE_SCHEMA ILIKE
\'`+*dbSchema+`\'`)
+ result[key] = append(result[key], internal.TableInfo{
+ Name: data.TblName.String, TableType:
data.TblType.String})
}
- if tableName != nil && *tableName != "" {
- conditions = append(conditions, ` TABLE_NAME ILIKE
\'`+*tableName+`\'`)
- }
- // if we need to include the schemas of the tables, make
another fetch
- // to fetch the columns and column info
- if columnName != nil && *columnName != "" {
- conditions = append(conditions, ` column_name ILIKE
\'`+*columnName+`\'`)
- }
- cond = strings.Join(conditions, " AND ")
- if cond != "" {
- cond = " WHERE " + cond
- }
- cond = `statement := 'SELECT * FROM (' || statement || ')` +
cond +
- ` ORDER BY table_catalog, table_schema, table_name,
ordinal_position';`
-
- var queryPrefix = `DECLARE
- c1 CURSOR FOR SELECT DATABASE_NAME FROM
INFORMATION_SCHEMA.DATABASES;
- res RESULTSET;
- counter INTEGER DEFAULT 0;
- statement VARCHAR DEFAULT '';
- BEGIN
- FOR rec IN c1 DO
- IF (counter > 0) THEN
- statement := statement || ' UNION ALL ';
- END IF;
- `
-
- const getSchema = `statement := statement ||
- ' SELECT
- table_catalog, table_schema,
table_name, column_name,
- ordinal_position, is_nullable::boolean,
data_type, numeric_precision,
- numeric_precision_radix, numeric_scale,
is_identity::boolean,
- identity_generation, identity_increment,
- character_maximum_length,
character_octet_length, datetime_precision, comment
- FROM ' || rec.database_name ||
'.INFORMATION_SCHEMA.COLUMNS';
-
- counter := counter + 1;
- END FOR;
- `
-
- const querySuffix = `
- res := (EXECUTE IMMEDIATE :statement);
- RETURN TABLE (res);
- END;`
-
- if catalog != nil && *catalog != "" {
- queryPrefix = `DECLARE
- c1 CURSOR FOR SELECT DATABASE_NAME FROM
INFORMATION_SCHEMA.DATABASES WHERE DATABASE_NAME ILIKE '` + *catalog + `';` +
- `res RESULTSET;
- counter INTEGER DEFAULT 0;
- statement VARCHAR DEFAULT '';
- BEGIN
- FOR rec IN c1 DO
- IF (counter > 0) THEN
- statement := statement || '
UNION ALL ';
- END IF;
- `
- }
- query = queryPrefix + getSchema + cond + querySuffix
- rows, err = c.sqldb.QueryContext(ctx, query)
- if err != nil {
- return
- }
- defer rows.Close()
+ }
+ if includeSchema {
var (
- colName, dataType
string
- identGen, identIncrement, comment
sql.NullString
- ordinalPos
int
- numericPrec, numericPrecRadix, numericScale,
datetimePrec sql.NullInt16
- isNullable, isIdent
bool
- charMaxLength, charOctetLength
sql.NullInt32
-
prevKey internal.CatalogAndSchema
curTableInfo *internal.TableInfo
fieldList = make([]arrow.Field, 0)
)
- for rows.Next() {
- // order here matches the order of the columns
requested in the query
- err = rows.Scan(&tblCat, &tblSchema, &tblName, &colName,
- &ordinalPos, &isNullable, &dataType,
&numericPrec,
- &numericPrecRadix, &numericScale, &isIdent,
&identGen,
- &identIncrement, &charMaxLength,
&charOctetLength, &datetimePrec, &comment)
- if err != nil {
- err = errToAdbcErr(adbc.StatusIO, err)
- return
+ for _, data := range metadataRecords {
+ if !data.Dbname.Valid || !data.Schema.Valid ||
!data.TblName.Valid {
+ continue
}
- key := internal.CatalogAndSchema{Catalog: tblCat,
Schema: tblSchema}
- if prevKey != key || (curTableInfo != nil &&
curTableInfo.Name != tblName) {
+ key := internal.CatalogAndSchema{Catalog:
data.Dbname.String, Schema: data.Schema.String}
+ if prevKey != key || (curTableInfo != nil &&
curTableInfo.Name != data.TblName.String) {
if len(fieldList) > 0 && curTableInfo != nil {
curTableInfo.Schema =
arrow.NewSchema(fieldList, nil)
fieldList = fieldList[:0]
@@ -638,7 +507,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth
adbc.ObjectDepth, cat
info := result[key]
for i := range info {
- if info[i].Name == tblName {
+ if info[i].Name == data.TblName.String {
curTableInfo = &info[i]
break
}
@@ -646,7 +515,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth
adbc.ObjectDepth, cat
}
prevKey = key
- fieldList = append(fieldList, toField(colName,
isNullable, dataType, numericPrec, numericPrecRadix, numericScale, isIdent,
c.useHighPrecision, identGen, identIncrement, charMaxLength, charOctetLength,
datetimePrec, comment, ordinalPos))
+ fieldList = append(fieldList, toField(data.ColName,
data.IsNullable, data.DataType, data.NumericPrec, data.NumericPrecRadix,
data.NumericScale, data.IsIdent, c.useHighPrecision, data.IdentGen,
data.IdentIncrement, data.CharMaxLength, data.CharOctetLength,
data.DatetimePrec, data.Comment, data.OrdinalPos))
}
if len(fieldList) > 0 && curTableInfo != nil {
@@ -656,6 +525,267 @@ func (c *cnxn) getObjectsTables(ctx context.Context,
depth adbc.ObjectDepth, cat
return
}
+func (c *cnxn) populateMetadata(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, dbSchema *string, tableName *string, columnName *string,
tableType []string) ([]internal.Metadata, error) {
+ var metadataRecords []internal.Metadata
+ catalogMetadataRecords, err := c.getCatalogsMetadata(ctx)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+
+ matchingCatalogNames, err :=
getMatchingCatalogNames(catalogMetadataRecords, catalog)
+ if err != nil {
+ return nil, adbc.Error{
+ Msg: err.Error(),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ if depth == adbc.ObjectDepthCatalogs {
+ metadataRecords = catalogMetadataRecords
+ } else if depth == adbc.ObjectDepthDBSchemas {
+ metadataRecords, err = c.getDbSchemasMetadata(ctx,
matchingCatalogNames, catalog, dbSchema)
+ } else if depth == adbc.ObjectDepthTables {
+ metadataRecords, err = c.getTablesMetadata(ctx,
matchingCatalogNames, catalog, dbSchema, tableName, tableType)
+ } else {
+ metadataRecords, err = c.getColumnsMetadata(ctx,
matchingCatalogNames, catalog, dbSchema, tableName, columnName, tableType)
+ }
+
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+
+ return metadataRecords, nil
+}
+
+func (c *cnxn) getCatalogsMetadata(ctx context.Context) ([]internal.Metadata,
error) {
+ metadataRecords := make([]internal.Metadata, 0)
+
+ rows, err := c.sqldb.QueryContext(ctx, prepareCatalogsSQL(), nil)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+
+ for rows.Next() {
+ var data internal.Metadata
+ var skipDbNullField, skipSchemaNullField sql.NullString
+ // schema for SHOW TERSE DATABASES is:
+ // created_on:timestamp, name:text, kind:null,
database_name:null, schema_name:null
+ // the last three columns are always null because they are not
applicable for databases
+ // so we want values[1].(string) for the name
+ if err := rows.Scan(&data.Created, &data.Dbname, &data.Kind,
&skipDbNullField, &skipSchemaNullField); err != nil {
+ return nil, errToAdbcErr(adbc.StatusInvalidData, err)
+ }
+
+ // SNOWFLAKE catalog contains functions and no tables
+ if data.Dbname.Valid && data.Dbname.String == "SNOWFLAKE" {
+ continue
+ }
+
+ metadataRecords = append(metadataRecords, data)
+ }
+ return metadataRecords, nil
+}
+
+func (c *cnxn) getDbSchemasMetadata(ctx context.Context, matchingCatalogNames
[]string, catalog *string, dbSchema *string) ([]internal.Metadata, error) {
+ var metadataRecords []internal.Metadata
+ query, queryArgs := prepareDbSchemasSQL(matchingCatalogNames, catalog,
dbSchema)
+ rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var data internal.Metadata
+ if err = rows.Scan(&data.Dbname, &data.Schema); err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ metadataRecords = append(metadataRecords, data)
+ }
+ return metadataRecords, nil
+}
+
+func (c *cnxn) getTablesMetadata(ctx context.Context, matchingCatalogNames
[]string, catalog *string, dbSchema *string, tableName *string, tableType
[]string) ([]internal.Metadata, error) {
+ metadataRecords := make([]internal.Metadata, 0)
+ query, queryArgs := prepareTablesSQL(matchingCatalogNames, catalog,
dbSchema, tableName, tableType)
+ rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var data internal.Metadata
+ if err = rows.Scan(&data.Dbname, &data.Schema, &data.TblName,
&data.TblType); err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ metadataRecords = append(metadataRecords, data)
+ }
+ return metadataRecords, nil
+}
+
+func (c *cnxn) getColumnsMetadata(ctx context.Context, matchingCatalogNames
[]string, catalog *string, dbSchema *string, tableName *string, columnName
*string, tableType []string) ([]internal.Metadata, error) {
+ metadataRecords := make([]internal.Metadata, 0)
+ query, queryArgs := prepareColumnsSQL(matchingCatalogNames, catalog,
dbSchema, tableName, columnName, tableType)
+ rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ defer rows.Close()
+
+ var data internal.Metadata
+
+ for rows.Next() {
+ // order here matches the order of the columns requested in the
query
+ err = rows.Scan(&data.TblType, &data.Dbname, &data.Schema,
&data.TblName, &data.ColName,
+ &data.OrdinalPos, &data.IsNullable, &data.DataType,
&data.NumericPrec,
+ &data.NumericPrecRadix, &data.NumericScale,
&data.IsIdent, &data.IdentGen,
+ &data.IdentIncrement, &data.CharMaxLength,
&data.CharOctetLength, &data.DatetimePrec, &data.Comment)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ metadataRecords = append(metadataRecords, data)
+ }
+ return metadataRecords, nil
+}
+
+func getMatchingCatalogNames(metadataRecords []internal.Metadata, catalog
*string) ([]string, error) {
+ matchingCatalogNames := make([]string, 0)
+ var catalogPattern *regexp.Regexp
+ var err error
+ if catalogPattern, err = internal.PatternToRegexp(catalog); err != nil {
+ return nil, adbc.Error{
+ Msg: err.Error(),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ for _, data := range metadataRecords {
+ if data.Dbname.Valid && data.Dbname.String == "SNOWFLAKE" {
+ continue
+ }
+ if catalogPattern != nil &&
!catalogPattern.MatchString(data.Dbname.String) {
+ continue
+ }
+
+ matchingCatalogNames = append(matchingCatalogNames,
data.Dbname.String)
+ }
+ return matchingCatalogNames, nil
+}
+
+func prepareCatalogsSQL() string {
+ return "SHOW TERSE DATABASES"
+}
+
+func prepareDbSchemasSQL(matchingCatalogNames []string, catalog *string,
dbSchema *string) (string, []interface{}) {
+ query := ""
+ for _, catalog_name := range matchingCatalogNames {
+ if query != "" {
+ query += " UNION ALL "
+ }
+ query += `SELECT * FROM "` + strings.ReplaceAll(catalog_name,
"\"", "\"\"") + `".INFORMATION_SCHEMA.SCHEMATA`
+ }
+
+ query = `SELECT CATALOG_NAME, SCHEMA_NAME FROM (` + query + `)`
+ conditions, queryArgs :=
prepareFilterConditions(adbc.ObjectDepthDBSchemas, catalog, dbSchema, nil, nil,
make([]string, 0))
+ if conditions != "" {
+ query += " WHERE " + conditions
+ }
+
+ return query, queryArgs
+}
+
+func prepareTablesSQL(matchingCatalogNames []string, catalog *string, dbSchema
*string, tableName *string, tableType []string) (string, []interface{}) {
+ query := ""
+ for _, catalog_name := range matchingCatalogNames {
+ if query != "" {
+ query += " UNION ALL "
+ }
+ query += `SELECT * FROM "` + strings.ReplaceAll(catalog_name,
"\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES`
+ }
+
+ query = `SELECT table_catalog, table_schema, table_name, table_type
FROM (` + query + `)`
+ conditions, queryArgs :=
prepareFilterConditions(adbc.ObjectDepthTables, catalog, dbSchema, tableName,
nil, tableType)
+ if conditions != "" {
+ query += " WHERE " + conditions
+ }
+ return query, queryArgs
+}
+
+func prepareColumnsSQL(matchingCatalogNames []string, catalog *string,
dbSchema *string, tableName *string, columnName *string, tableType []string)
(string, []interface{}) {
+ prefixQuery := ""
+ for _, catalog_name := range matchingCatalogNames {
+ if prefixQuery != "" {
+ prefixQuery += " UNION ALL "
+ }
+ prefixQuery += `SELECT T.table_type,
+ C.*
+ FROM
+ "` + strings.ReplaceAll(catalog_name, "\"",
"\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
+ JOIN
+ "` + strings.ReplaceAll(catalog_name, "\"",
"\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
+ ON
+ T.table_catalog = C.table_catalog
+ AND T.table_schema = C.table_schema
+ AND t.table_name = C.table_name`
+ }
+
+ prefixQuery = `SELECT table_type, table_catalog, table_schema,
table_name, column_name,
+ ordinal_position,
is_nullable::boolean, data_type, numeric_precision,
+ numeric_precision_radix,
numeric_scale, is_identity::boolean,
+ identity_generation,
identity_increment,
+ character_maximum_length,
character_octet_length, datetime_precision, comment FROM (` + prefixQuery + `)`
+ ordering := ` ORDER BY table_catalog, table_schema, table_name,
ordinal_position`
+ conditions, queryArgs :=
prepareFilterConditions(adbc.ObjectDepthColumns, catalog, dbSchema, tableName,
columnName, tableType)
+ query := prefixQuery
+
+ if conditions != "" {
+ query += " WHERE " + conditions
+ }
+
+ query += ordering
+ return query, queryArgs
+}
+
+func prepareFilterConditions(depth adbc.ObjectDepth, catalog *string, dbSchema
*string, tableName *string, columnName *string, tableType []string) (string,
[]interface{}) {
+ conditions := make([]string, 0)
+ queryArgs := make([]interface{}, 0)
+ if catalog != nil && *catalog != "" {
+ if depth == adbc.ObjectDepthDBSchemas {
+ conditions = append(conditions, ` CATALOG_NAME ILIKE ?
`)
+ } else {
+ conditions = append(conditions, ` TABLE_CATALOG ILIKE ?
`)
+ }
+ queryArgs = append(queryArgs, *catalog)
+ }
+ if dbSchema != nil && *dbSchema != "" {
+ if depth == adbc.ObjectDepthDBSchemas {
+ conditions = append(conditions, ` SCHEMA_NAME ILIKE ? `)
+ } else {
+ conditions = append(conditions, ` TABLE_SCHEMA ILIKE ?
`)
+ }
+ queryArgs = append(queryArgs, *dbSchema)
+ }
+ if tableName != nil && *tableName != "" {
+ conditions = append(conditions, ` TABLE_NAME ILIKE ? `)
+ queryArgs = append(queryArgs, *tableName)
+ }
+ if columnName != nil && *columnName != "" {
+ conditions = append(conditions, ` COLUMN_NAME ILIKE ? `)
+ queryArgs = append(queryArgs, *columnName)
+ }
+
+ var tblConditions []string
+ if len(tableType) > 0 {
+ tblConditions = append(conditions, ` TABLE_TYPE IN
('`+strings.Join(tableType, `','`)+`')`)
+ } else {
+ tblConditions = conditions
+ }
+
+ cond := strings.Join(tblConditions, " AND ")
+ return cond, queryArgs
+}
+
func descToField(name, typ, isnull, primary string, comment sql.NullString)
(field arrow.Field, err error) {
field.Name = strings.ToLower(name)
if isnull == "Y" {
@@ -823,12 +953,12 @@ func (c *cnxn) GetOptionDouble(key string) (float64,
error) {
func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema
*string, tableName string) (*arrow.Schema, error) {
tblParts := make([]string, 0, 3)
if catalog != nil {
- tblParts = append(tblParts,
strconv.Quote(strings.ToUpper(*catalog)))
+ tblParts = append(tblParts, strconv.Quote(*catalog))
}
if dbSchema != nil {
- tblParts = append(tblParts,
strconv.Quote(strings.ToUpper(*dbSchema)))
+ tblParts = append(tblParts, strconv.Quote(*dbSchema))
}
- tblParts = append(tblParts, strconv.Quote(strings.ToUpper(tableName)))
+ tblParts = append(tblParts, strconv.Quote(tableName))
fullyQualifiedTable := strings.Join(tblParts, ".")
rows, err := c.sqldb.QueryContext(ctx, `DESC TABLE
`+fullyQualifiedTable)
diff --git a/go/adbc/driver/snowflake/connection_test.go
b/go/adbc/driver/snowflake/connection_test.go
new file mode 100644
index 00000000..983a206d
--- /dev/null
+++ b/go/adbc/driver/snowflake/connection_test.go
@@ -0,0 +1,320 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+ "fmt"
+ "regexp"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestPrepareCatalogsSQL(t *testing.T) {
+ expected := "SHOW TERSE DATABASES"
+ actual := prepareCatalogsSQL()
+
+ assert.Equal(t, expected, actual, "The expected SQL query for catalogs
is not being generated")
+}
+
+func TestPrepareDbSchemasSQLWithNoFilterOneCatalog(t *testing.T) {
+ catalogNames := [1]string{"DEMO_DB"}
+ catalogPattern := ""
+ schemaPattern := ""
+
+ expected := `SELECT CATALOG_NAME, SCHEMA_NAME
+ FROM
+ (
+ SELECT * FROM
"DEMO_DB".INFORMATION_SCHEMA.SCHEMATA
+ )`
+ actual, queryArgs := prepareDbSchemasSQL(catalogNames[:],
&catalogPattern, &schemaPattern)
+
+ println("Query Args", queryArgs)
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for DbSchemas is not being generated")
+}
+
+func TestPrepareDbSchemasSQLWithNoFilter(t *testing.T) {
+ catalogNames := [3]string{"DEMO_DB", "DEMO'DB", "HELLO_DB"}
+ catalogPattern := ""
+ schemaPattern := ""
+
+ expected := `SELECT CATALOG_NAME, SCHEMA_NAME
+ FROM
+ (
+ SELECT * FROM
"DEMO_DB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"DEMO'DB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"HELLO_DB".INFORMATION_SCHEMA.SCHEMATA
+ )`
+ actual, queryArgs := prepareDbSchemasSQL(catalogNames[:],
&catalogPattern, &schemaPattern)
+
+ println("Query Args", queryArgs)
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for DbSchemas is not being generated")
+}
+
+func TestPrepareDbSchemasSQLWithCatalogFilter(t *testing.T) {
+ catalogNames := [3]string{"DEMO_DB", "DEMO'DB", "HELLO_DB"}
+ catalogPattern := "DEMO_DB"
+ schemaPattern := ""
+
+ expected := `SELECT CATALOG_NAME, SCHEMA_NAME
+ FROM
+ (
+ SELECT * FROM
"DEMO_DB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"DEMO'DB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"HELLO_DB".INFORMATION_SCHEMA.SCHEMATA
+ )
+ WHERE CATALOG_NAME ILIKE ? `
+
+ actual, queryArgs := prepareDbSchemasSQL(catalogNames[:],
&catalogPattern, &schemaPattern)
+
+ stringqueryArgs := make([]string, len(queryArgs)) // Pre-allocate the
right size
+ for index := range queryArgs {
+ stringqueryArgs[index] = fmt.Sprintf("%v", queryArgs[index])
+ }
+
+ assert.True(t, areStringsEquivalent(catalogPattern,
strings.Join(stringqueryArgs, ",")), "The expected CATALOG_NAME is not being
generated")
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for DbSchemas is not being generated")
+}
+
+func TestPrepareDbSchemasSQLWithSchemaFilter(t *testing.T) {
+ catalogNames := [3]string{"DEMO_DB", "DEMO'DB", "HELLO_DB"}
+ catalogPattern := ""
+ schemaPattern := "PUBLIC"
+
+ expected := `SELECT CATALOG_NAME, SCHEMA_NAME
+ FROM
+ (
+ SELECT * FROM
"DEMO_DB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"DEMO'DB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"HELLO_DB".INFORMATION_SCHEMA.SCHEMATA
+ )
+ WHERE SCHEMA_NAME ILIKE ? `
+ actual, queryArgs := prepareDbSchemasSQL(catalogNames[:],
&catalogPattern, &schemaPattern)
+
+ stringqueryArgs := make([]string, len(queryArgs)) // Pre-allocate the
right size
+ for index := range queryArgs {
+ stringqueryArgs[index] = fmt.Sprintf("%v", queryArgs[index])
+ }
+
+ assert.True(t, areStringsEquivalent(schemaPattern,
strings.Join(stringqueryArgs, ",")), "The expected SCHEMA_NAME is not being
generated")
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for DbSchemas is not being generated")
+}
+
+func TestPrepareDbSchemasSQL(t *testing.T) {
+ catalogNames := [4]string{"DEMO_DB", "DEMOADB", "DEMO'DB", "HELLO_DB"}
+ catalogPattern := "DEMO_DB"
+ schemaPattern := "PUBLIC"
+
+ expected := `SELECT CATALOG_NAME, SCHEMA_NAME
+ FROM
+ (
+ SELECT * FROM
"DEMO_DB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"DEMOADB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"DEMO'DB".INFORMATION_SCHEMA.SCHEMATA
+ UNION ALL
+ SELECT * FROM
"HELLO_DB".INFORMATION_SCHEMA.SCHEMATA
+ )
+ WHERE CATALOG_NAME ILIKE ? AND
SCHEMA_NAME ILIKE ? `
+ actual, queryArgs := prepareDbSchemasSQL(catalogNames[:],
&catalogPattern, &schemaPattern)
+
+ stringqueryArgs := make([]string, len(queryArgs)) // Pre-allocate the
right size
+ for index := range queryArgs {
+ stringqueryArgs[index] = fmt.Sprintf("%v", queryArgs[index])
+ }
+
+ assert.True(t, areStringsEquivalent(catalogPattern+","+schemaPattern,
strings.Join(stringqueryArgs, ",")), "The expected SCHEMA_NAME is not being
generated")
+
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for DbSchemas is not being generated")
+}
+
+func TestPrepareTablesSQLWithNoFilter(t *testing.T) {
+ catalogNames := [3]string{"DEMO_DB", "DEMOADB", "DEMO'DB"}
+ catalogPattern := ""
+ schemaPattern := ""
+ tableNamePattern := ""
+ tableType := make([]string, 0)
+
+ expected := `SELECT table_catalog, table_schema, table_name, table_type
+ FROM
+ (
+ SELECT * FROM
"DEMO_DB".INFORMATION_SCHEMA.TABLES
+ UNION ALL
+ SELECT * FROM
"DEMOADB".INFORMATION_SCHEMA.TABLES
+ UNION ALL
+ SELECT * FROM
"DEMO'DB".INFORMATION_SCHEMA.TABLES
+ )`
+ actual, queryArgs := prepareTablesSQL(catalogNames[:], &catalogPattern,
&schemaPattern, &tableNamePattern, tableType[:])
+
+ println("Query Args", queryArgs)
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for Tables is not being generated")
+}
+
+func TestPrepareTablesSQLWithNoTableTypeFilter(t *testing.T) {
+ catalogNames := [3]string{"DEMO_DB", "DEMOADB", "DEMO'DB"}
+ catalogPattern := "DEMO_DB"
+ schemaPattern := "PUBLIC"
+ tableNamePattern := "ADBC-TABLE"
+ tableType := make([]string, 0)
+
+ expected := `SELECT table_catalog, table_schema, table_name, table_type
+ FROM
+ (
+ SELECT * FROM
"DEMO_DB".INFORMATION_SCHEMA.TABLES
+ UNION ALL
+ SELECT * FROM
"DEMOADB".INFORMATION_SCHEMA.TABLES
+ UNION ALL
+ SELECT * FROM
"DEMO'DB".INFORMATION_SCHEMA.TABLES
+ )
+ WHERE TABLE_CATALOG ILIKE ? AND
TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? `
+ actual, queryArgs := prepareTablesSQL(catalogNames[:], &catalogPattern,
&schemaPattern, &tableNamePattern, tableType[:])
+
+ stringqueryArgs := make([]string, len(queryArgs)) // Pre-allocate the
right size
+ for index := range queryArgs {
+ stringqueryArgs[index] = fmt.Sprintf("%v", queryArgs[index])
+ }
+
+ assert.True(t,
areStringsEquivalent(catalogPattern+","+schemaPattern+","+tableNamePattern,
strings.Join(stringqueryArgs, ",")), "The expected SCHEMA_NAME is not being
generated")
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for Tables is not being generated")
+}
+
+func TestPrepareTablesSQL(t *testing.T) {
+ catalogNames := [3]string{"DEMO_DB", "DEMOADB", "DEMO'DB"}
+ catalogPattern := "DEMO_DB"
+ schemaPattern := "PUBLIC"
+ tableNamePattern := "ADBC-TABLE"
+ tableType := [2]string{"BASE TABLE", "VIEW"}
+
+ expected := `SELECT table_catalog, table_schema, table_name, table_type
+ FROM
+ (
+ SELECT * FROM
"DEMO_DB".INFORMATION_SCHEMA.TABLES
+ UNION ALL
+ SELECT * FROM
"DEMOADB".INFORMATION_SCHEMA.TABLES
+ UNION ALL
+ SELECT * FROM
"DEMO'DB".INFORMATION_SCHEMA.TABLES
+ )
+ WHERE TABLE_CATALOG ILIKE ? AND
TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? AND TABLE_TYPE IN ('BASE
TABLE','VIEW')`
+ actual, queryArgs := prepareTablesSQL(catalogNames[:], &catalogPattern,
&schemaPattern, &tableNamePattern, tableType[:])
+
+ stringqueryArgs := make([]string, len(queryArgs)) // Pre-allocate the
right size
+ for index := range queryArgs {
+ stringqueryArgs[index] = fmt.Sprintf("%v", queryArgs[index])
+ }
+
+ assert.True(t,
areStringsEquivalent(catalogPattern+","+schemaPattern+","+tableNamePattern,
strings.Join(stringqueryArgs, ",")), "The expected SCHEMA_NAME is not being
generated")
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for Tables is not being generated")
+}
+
+func TestPrepareColumnsSQLNoFilter(t *testing.T) {
+ catalogNames := [2]string{"DEMO_DB", "DEMOADB"}
+ catalogPattern := ""
+ schemaPattern := ""
+ tableNamePattern := ""
+ columnNamePattern := ""
+ tableType := make([]string, 0)
+
+ expected := `SELECT table_type, table_catalog, table_schema,
table_name, column_name,
+ ordinal_position,
is_nullable::boolean, data_type, numeric_precision,
+ numeric_precision_radix,
numeric_scale, is_identity::boolean,
+ identity_generation,
identity_increment,
+ character_maximum_length,
character_octet_length, datetime_precision, comment
+ FROM
+ (
+ SELECT T.table_type, C.*
+ FROM
+
"DEMO_DB".INFORMATION_SCHEMA.TABLES AS T
+ JOIN
+
"DEMO_DB".INFORMATION_SCHEMA.COLUMNS AS C
+ ON
+ T.table_catalog
= C.table_catalog AND T.table_schema = C.table_schema AND t.table_name =
C.table_name
+ UNION ALL
+ SELECT T.table_type, C.*
+ FROM
+
"DEMOADB".INFORMATION_SCHEMA.TABLES AS T
+ JOIN
+
"DEMOADB".INFORMATION_SCHEMA.COLUMNS AS C
+ ON
+ T.table_catalog
= C.table_catalog AND T.table_schema = C.table_schema AND t.table_name =
C.table_name
+ )
+ ORDER BY table_catalog,
table_schema, table_name, ordinal_position`
+ actual, queryArgs := prepareColumnsSQL(catalogNames[:],
&catalogPattern, &schemaPattern, &tableNamePattern, &columnNamePattern,
tableType[:])
+
+ println("Query Args", queryArgs)
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for Tables is not being generated")
+}
+
+func TestPrepareColumnsSQL(t *testing.T) {
+ catalogNames := [2]string{"DEMO_DB", "DEMOADB"}
+ catalogPattern := "DEMO_DB"
+ schemaPattern := "PUBLIC"
+ tableNamePattern := "ADBC-TABLE"
+ columnNamePattern := "creationDate"
+ tableType := [2]string{"BASE TABLE", "VIEW"}
+
+ expected := `SELECT table_type, table_catalog, table_schema,
table_name, column_name,
+ ordinal_position,
is_nullable::boolean, data_type, numeric_precision,
+ numeric_precision_radix,
numeric_scale, is_identity::boolean,
+ identity_generation,
identity_increment,
+ character_maximum_length,
character_octet_length, datetime_precision, comment
+ FROM
+ (
+ SELECT T.table_type, C.*
+ FROM
+
"DEMO_DB".INFORMATION_SCHEMA.TABLES AS T
+ JOIN
+
"DEMO_DB".INFORMATION_SCHEMA.COLUMNS AS C
+ ON
+ T.table_catalog
= C.table_catalog AND T.table_schema = C.table_schema AND t.table_name =
C.table_name
+ UNION ALL
+ SELECT T.table_type, C.*
+ FROM
+
"DEMOADB".INFORMATION_SCHEMA.TABLES AS T
+ JOIN
+
"DEMOADB".INFORMATION_SCHEMA.COLUMNS AS C
+ ON
+ T.table_catalog
= C.table_catalog AND T.table_schema = C.table_schema AND t.table_name =
C.table_name
+ )
+ WHERE TABLE_CATALOG ILIKE ?
AND TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? AND COLUMN_NAME ILIKE ? AND
TABLE_TYPE IN ('BASE TABLE','VIEW')
+ ORDER BY table_catalog,
table_schema, table_name, ordinal_position`
+ actual, queryArgs := prepareColumnsSQL(catalogNames[:],
&catalogPattern, &schemaPattern, &tableNamePattern, &columnNamePattern,
tableType[:])
+
+ stringqueryArgs := make([]string, len(queryArgs)) // Pre-allocate the
right size
+ for index := range queryArgs {
+ stringqueryArgs[index] = fmt.Sprintf("%v", queryArgs[index])
+ }
+
+ assert.True(t,
areStringsEquivalent(catalogPattern+","+schemaPattern+","+tableNamePattern+","+columnNamePattern,
strings.Join(stringqueryArgs, ",")), "The expected SCHEMA_NAME is not being
generated")
+ assert.True(t, areStringsEquivalent(expected, actual), "The expected
SQL query for Tables is not being generated")
+}
+
// areStringsEquivalent reports whether two strings are identical once every
// run of whitespace (the RE2 class \s) has been deleted. The tests use it to
// compare generated SQL without depending on indentation or line breaks.
// Whitespace is removed rather than collapsed, so strings that differ only in
// where spaces fall (e.g. "A B" vs "AB") also compare equal.
func areStringsEquivalent(str1 string, str2 string) bool {
	whitespace := regexp.MustCompile(`\s+`)
	squeeze := func(s string) string {
		return whitespace.ReplaceAllString(s, "")
	}
	return squeeze(str1) == squeeze(str2)
}