This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 5471d955c feat(go/adbc/driver/snowflake): improve GetObjects performance and semantics (#2254)
5471d955c is described below
commit 5471d955c3e94f68aced979ddf8ab6d98fd9a098
Author: Matt Topol <[email protected]>
AuthorDate: Thu Oct 17 12:44:26 2024 -0600
feat(go/adbc/driver/snowflake): improve GetObjects performance and semantics (#2254)
Fixes #2171
Improves the channel handling and query building for converting
metadata to Arrow, for better performance.
For every case except retrieving Column metadata, we now utilize `SHOW`
queries and build the filter patterns directly into those queries. This
allows `GetObjects` calls at those depths to run without a current
database or schema being specified.
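As a rough sketch of the new query building (hedged: `addLike` and
`isWildcardStr` below mirror the helpers added in
go/adbc/driver/snowflake/connection.go, but the quote escaping is
simplified and the identifiers are made-up examples, not the driver's
exact code):

    package main

    import (
        "fmt"
        "strings"
    )

    // isWildcardStr reports whether an identifier contains SQL LIKE
    // wildcards, in which case it cannot be used as a literal name.
    func isWildcardStr(ident string) bool {
        return strings.ContainsAny(ident, "_%")
    }

    // addLike folds a non-trivial pattern into the SHOW query itself so
    // the server performs the filtering (escaping simplified here).
    func addLike(query string, pattern *string) string {
        if pattern != nil && len(*pattern) > 0 && *pattern != "%" && *pattern != ".*" {
            query += " LIKE '" + strings.ReplaceAll(*pattern, "'", `\'`) + "'"
        }
        return query
    }

    func main() {
        schemaPattern := "PUBLIC%" // hypothetical filter
        catalog := "MY_DB"         // hypothetical catalog name

        query := addLike("SHOW TERSE /* ADBC:getObjects */ SCHEMAS", &schemaPattern)
        if isWildcardStr(catalog) {
            query += " IN ACCOUNT"
        } else {
            query += " IN DATABASE " + catalog // the real driver quotes this
        }
        fmt.Println(query)
        // SHOW TERSE /* ADBC:getObjects */ SCHEMAS LIKE 'PUBLIC%' IN DATABASE MY_DB
    }

Because SHOW evaluates the pattern server-side, no session database or
schema needs to be set for these depths.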
---
c/driver/flightsql/sqlite_flightsql_test.cc | 1 +
c/driver/snowflake/snowflake_test.cc | 12 +-
c/validation/adbc_validation.h | 6 +
c/validation/adbc_validation_connection.cc | 27 +-
c/validation/adbc_validation_statement.cc | 16 +-
go/adbc/driver/internal/driverbase/connection.go | 51 ++--
go/adbc/driver/snowflake/connection.go | 334 ++++++++++++++-------
go/adbc/driver/snowflake/driver_test.go | 11 +-
.../driver/snowflake/queries/get_objects_all.sql | 18 +-
.../snowflake/queries/get_objects_catalogs.sql | 25 --
.../snowflake/queries/get_objects_dbschemas.sql | 28 +-
.../snowflake/queries/get_objects_tables.sql | 112 ++-----
go/adbc/driver/snowflake/statement.go | 4 +-
13 files changed, 365 insertions(+), 280 deletions(-)
diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc
index 454ea0297..40601e280 100644
--- a/c/driver/flightsql/sqlite_flightsql_test.cc
+++ b/c/driver/flightsql/sqlite_flightsql_test.cc
@@ -121,6 +121,7 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks {
bool supports_get_objects() const override { return true; }
bool supports_partitioned_data() const override { return true; }
bool supports_dynamic_parameter_binding() const override { return true; }
+ std::string catalog() const { return "main"; }
};
class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::DatabaseTest {
diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc
index 60003353d..262286192 100644
--- a/c/driver/snowflake/snowflake_test.cc
+++ b/c/driver/snowflake/snowflake_test.cc
@@ -99,7 +99,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
adbc_validation::Handle<struct AdbcStatement> statement;
CHECK_OK(AdbcStatementNew(connection, &statement.value, error));
- std::string create = "CREATE TABLE \"";
+ std::string create = "CREATE OR REPLACE TABLE \"";
create += name;
create += "\" (int64s INT, strings TEXT)";
CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error));
@@ -131,7 +131,13 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
return NANOARROW_TYPE_DOUBLE;
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
+ case NANOARROW_TYPE_LIST:
+ case NANOARROW_TYPE_LARGE_LIST:
return NANOARROW_TYPE_STRING;
+ case NANOARROW_TYPE_BINARY:
+ case NANOARROW_TYPE_LARGE_BINARY:
+ case NANOARROW_TYPE_FIXED_SIZE_BINARY:
+ return NANOARROW_TYPE_BINARY;
default:
return ingest_type;
}
@@ -149,7 +155,11 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
bool supports_dynamic_parameter_binding() const override { return true; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
+ bool supports_ingest_view_types() const override { return false; }
+ bool supports_ingest_float16() const override { return false; }
+
std::string db_schema() const override { return schema_; }
+ std::string catalog() const override { return "ADBC_TESTING"; }
const char* uri_;
bool skip_{false};
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index fa3c1cdcc..f8ef350cc 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -238,6 +238,12 @@ class DriverQuirks {
/// column matching.
virtual bool supports_error_on_incompatible_schema() const { return true; }
+ /// \brief Whether ingestion supports StringView/BinaryView types
+ virtual bool supports_ingest_view_types() const { return true; }
+
+ /// \brief Whether ingestion supports Float16
+ virtual bool supports_ingest_float16() const { return true; }
+
/// \brief Default catalog to use for tests
virtual std::string catalog() const { return ""; }
diff --git a/c/validation/adbc_validation_connection.cc b/c/validation/adbc_validation_connection.cc
index a885fa2c8..032f1d328 100644
--- a/c/validation/adbc_validation_connection.cc
+++ b/c/validation/adbc_validation_connection.cc
@@ -744,13 +744,15 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
struct TestCase {
std::optional<std::string> filter;
- std::vector<std::string> column_names;
- std::vector<int32_t> ordinal_positions;
+ // the pair is column name & ordinal position of the column
+ std::vector<std::pair<std::string, int32_t>> columns;
};
std::vector<TestCase> test_cases;
- test_cases.push_back({std::nullopt, {"int64s", "strings"}, {1, 2}});
- test_cases.push_back({"in%", {"int64s"}, {1}});
+ test_cases.push_back({std::nullopt, {{"int64s", 1}, {"strings", 2}}});
+ test_cases.push_back({"in%", {{"int64s", 1}}});
+
+ const std::string catalog = quirks()->catalog();
for (const auto& test_case : test_cases) {
std::string scope = "Filter: ";
@@ -758,13 +760,14 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
SCOPED_TRACE(scope);
StreamReader reader;
+ std::vector<std::pair<std::string, int32_t>> columns;
std::vector<std::string> column_names;
std::vector<int32_t> ordinal_positions;
ASSERT_THAT(
AdbcConnectionGetObjects(
- &connection, ADBC_OBJECT_DEPTH_COLUMNS, nullptr, nullptr, nullptr, nullptr,
- test_case.filter.has_value() ? test_case.filter->c_str() : nullptr,
+ &connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr,
+ nullptr, test_case.filter.has_value() ? test_case.filter->c_str() : nullptr,
&reader.stream.value, &error),
IsOkStatus(&error));
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
@@ -834,10 +837,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
std::string temp(name.data, name.size_bytes);
std::transform(temp.begin(), temp.end(), temp.begin(),
[](unsigned char c) { return std::tolower(c); });
- column_names.push_back(std::move(temp));
- ordinal_positions.push_back(
- static_cast<int32_t>(ArrowArrayViewGetIntUnsafe(
- table_columns->children[1], columns_index)));
+ columns.emplace_back(std::move(temp),
+     static_cast<int32_t>(ArrowArrayViewGetIntUnsafe(
+         table_columns->children[1], columns_index)));
}
}
}
@@ -847,8 +849,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
} while (reader.array->release);
ASSERT_TRUE(found_expected_table) << "Did (not) find table in metadata";
- ASSERT_EQ(test_case.column_names, column_names);
- ASSERT_EQ(test_case.ordinal_positions, ordinal_positions);
+ // metadata columns do not guarantee the order they are returned in, just
+ // validate all the elements are there.
+ ASSERT_THAT(columns, testing::UnorderedElementsAreArray(test_case.columns));
}
}
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index 07ab0b22a..150aabf32 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -246,6 +246,10 @@ void StatementTest::TestSqlIngestInt64() {
}
void StatementTest::TestSqlIngestFloat16() {
+ if (!quirks()->supports_ingest_float16()) {
+ GTEST_SKIP();
+ }
+
ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<float>(NANOARROW_TYPE_HALF_FLOAT));
}
@@ -268,6 +272,10 @@ void StatementTest::TestSqlIngestLargeString() {
}
void StatementTest::TestSqlIngestStringView() {
+ if (!quirks()->supports_ingest_view_types()) {
+ GTEST_SKIP();
+ }
+
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_STRING_VIEW, {std::nullopt, "", "", "longer than 12 bytes", "δΎ‹"},
false));
@@ -302,6 +310,10 @@ void StatementTest::TestSqlIngestFixedSizeBinary() {
}
void StatementTest::TestSqlIngestBinaryView() {
+ if (!quirks()->supports_ingest_view_types()) {
+ GTEST_SKIP();
+ }
+
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::vector<std::byte>>(
NANOARROW_TYPE_LARGE_BINARY,
{std::nullopt, std::vector<std::byte>{},
@@ -2218,7 +2230,7 @@ void StatementTest::TestSqlBind() {
ASSERT_THAT(
AdbcStatementSetSqlQuery(
- &statement, "SELECT * FROM bindtest ORDER BY \"col1\" ASC NULLS FIRST", &error),
+ &statement, "SELECT * FROM bindtest ORDER BY col1 ASC NULLS FIRST", &error),
IsOkStatus(&error));
{
StreamReader reader;
@@ -2226,7 +2238,7 @@ void StatementTest::TestSqlBind() {
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(reader.rows_affected,
- ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+ ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1)));
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
diff --git a/go/adbc/driver/internal/driverbase/connection.go b/go/adbc/driver/internal/driverbase/connection.go
index 6e7881635..b09f74e30 100644
--- a/go/adbc/driver/internal/driverbase/connection.go
+++ b/go/adbc/driver/internal/driverbase/connection.go
@@ -349,14 +349,17 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
bufferSize := len(catalogs)
addCatalogCh := make(chan GetObjectsInfo, bufferSize)
- for _, cat := range catalogs {
- addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)}
- }
-
- close(addCatalogCh)
+ errCh := make(chan error, 1)
+ go func() {
+ defer close(addCatalogCh)
+ for _, cat := range catalogs {
+ addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)}
+ }
+ }()
if depth == adbc.ObjectDepthCatalogs {
- return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh)
+ close(errCh)
+ return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh, errCh)
}
g, ctxG := errgroup.WithContext(ctx)
@@ -386,7 +389,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
g.Go(func() error { defer close(addDbSchemasCh); return gSchemas.Wait() })
if depth == adbc.ObjectDepthDBSchemas {
- rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh)
+ rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh, errCh)
return rdr, errors.Join(err, g.Wait())
}
@@ -432,7 +435,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
g.Go(func() error { defer close(addTablesCh); return gTables.Wait() })
- rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh)
+ rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh, errCh)
return rdr, errors.Join(err, g.Wait())
}
@@ -621,20 +624,20 @@ type ColumnInfo struct {
type TableInfo struct {
TableName string `json:"table_name"`
TableType string `json:"table_type"`
- TableColumns []ColumnInfo `json:"table_columns,omitempty"`
- TableConstraints []ConstraintInfo `json:"table_constraints,omitempty"`
+ TableColumns []ColumnInfo `json:"table_columns"`
+ TableConstraints []ConstraintInfo `json:"table_constraints"`
}
// DBSchemaInfo is a structured representation of adbc.DBSchemaSchema
type DBSchemaInfo struct {
DbSchemaName *string `json:"db_schema_name,omitempty"`
- DbSchemaTables []TableInfo `json:"db_schema_tables,omitempty"`
+ DbSchemaTables []TableInfo `json:"db_schema_tables"`
}
// GetObjectsInfo is a structured representation of adbc.GetObjectsSchema
type GetObjectsInfo struct {
CatalogName *string `json:"catalog_name,omitempty"`
- CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas,omitempty"`
+ CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas"`
}
// Scan implements sql.Scanner.
@@ -659,23 +662,33 @@ func (g *GetObjectsInfo) Scan(src any) error {
// BuildGetObjectsRecordReader constructs a RecordReader for the GetObjects ADBC method.
// It accepts a channel of GetObjectsInfo to allow concurrent retrieval of metadata and
// serialization to Arrow record.
-func BuildGetObjectsRecordReader(mem memory.Allocator, in chan GetObjectsInfo) (array.RecordReader, error) {
+func BuildGetObjectsRecordReader(mem memory.Allocator, in <-chan GetObjectsInfo, errCh <-chan error) (array.RecordReader, error) {
bldr := array.NewRecordBuilder(mem, adbc.GetObjectsSchema)
defer bldr.Release()
- for catalog := range in {
- b, err := json.Marshal(catalog)
- if err != nil {
- return nil, err
- }
+CATALOGLOOP:
+ for {
+ select {
+ case catalog, ok := <-in:
+ if !ok {
+ break CATALOGLOOP
+ }
+ b, err := json.Marshal(catalog)
+ if err != nil {
+ return nil, err
+ }
- if err := json.Unmarshal(b, bldr); err != nil {
+ if err := json.Unmarshal(b, bldr); err != nil {
+ return nil, err
+ }
+ case err := <-errCh:
return nil, err
}
}
rec := bldr.NewRecord()
defer rec.Release()
+
return array.NewRecordReader(adbc.GetObjectsSchema, []arrow.Record{rec})
}
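Distilled, the new BuildGetObjectsRecordReader control flow is a drain
loop that selects between the data channel and an error channel, so a
producer failure aborts the read instead of deadlocking on a channel
that will never be closed. A minimal, self-contained sketch of that
pattern (the names here are illustrative, not the driverbase API):

    package main

    import (
        "errors"
        "fmt"
    )

    // drain consumes values until in closes, or returns early when the
    // producer reports a failure -- the same control flow as the
    // CATALOGLOOP select above.
    func drain(in <-chan string, errCh <-chan error) ([]string, error) {
        var out []string
    LOOP:
        for {
            select {
            case v, ok := <-in:
                if !ok {
                    break LOOP
                }
                out = append(out, v)
            case err := <-errCh:
                return nil, err
            }
        }
        return out, nil
    }

    func main() {
        in := make(chan string, 2)
        errCh := make(chan error, 1)
        go func() {
            defer close(in)
            in <- "catalog_a"
            in <- "catalog_b"
        }()
        got, err := drain(in, errCh)
        fmt.Println(got, err) // [catalog_a catalog_b] <nil>

        // failure path: the producer sends an error instead of closing in
        in2 := make(chan string)
        errCh2 := make(chan error, 1)
        errCh2 <- errors.New("scan failed")
        _, err = drain(in2, errCh2)
        fmt.Println(err) // scan failed
    }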
diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index a8361a365..190426c7f 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -24,7 +24,9 @@ import (
"embed"
"fmt"
"io"
+ "io/fs"
"path"
+ "runtime"
"strconv"
"strings"
"time"
@@ -42,7 +44,6 @@ const (
defaultPrefetchConcurrency = 10
queryTemplateGetObjectsAll = "get_objects_all.sql"
- queryTemplateGetObjectsCatalogs = "get_objects_catalogs.sql"
queryTemplateGetObjectsDbSchemas = "get_objects_dbschemas.sql"
queryTemplateGetObjectsTables = "get_objects_tables.sql"
queryTemplateGetObjectsTerseCatalogs = "get_objects_terse_catalogs.sql"
@@ -73,9 +74,105 @@ type connectionImpl struct {
useHighPrecision bool
}
-func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) {
+func escapeSingleQuoteForLike(arg string) string {
+ if len(arg) == 0 {
+ return arg
+ }
+
+ idx := strings.IndexByte(arg, '\'')
+ if idx == -1 {
+ return arg
+ }
+
+ var b strings.Builder
+ b.Grow(len(arg))
+
+ for {
+ before, after, found := strings.Cut(arg, `'`)
+ b.WriteString(before)
+ if !found {
+ return b.String()
+ }
+
+ if before[len(before)-1] != '\\' {
+ b.WriteByte('\\')
+ }
+ b.WriteByte('\'')
+ arg = after
+ }
+}
+
+func getQueryID(ctx context.Context, query string, driverConn any) (string, error) {
+ rows, err := driverConn.(driver.QueryerContext).QueryContext(ctx, query, nil)
+ if err != nil {
+ return "", err
+ }
+
+ return rows.(gosnowflake.SnowflakeRows).GetQueryID(), rows.Close()
+}
+
+const (
+ objSchemas = "SCHEMAS"
+ objDatabases = "DATABASES"
+ objViews = "VIEWS"
+ objTables = "TABLES"
+ objObjects = "OBJECTS"
+)
+
+func addLike(query string, pattern *string) string {
+ if pattern != nil && len(*pattern) > 0 && *pattern != "%" && *pattern != ".*" {
+ query += " LIKE '" + escapeSingleQuoteForLike(*pattern) + "'"
+ }
+ return query
+}
+
+func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema, tableName *string, outQueryID *string) {
+ grp.Go(func() error {
+ return conn.Raw(func(driverConn any) (err error) {
+ query := "SHOW TERSE /* ADBC:getObjects */ " + objType
+ switch objType {
+ case objDatabases:
+ query = addLike(query, catalog)
+ query += " IN ACCOUNT"
+ case objSchemas:
+ query = addLike(query, dbSchema)
+
+ if catalog == nil || isWildcardStr(*catalog) {
+ query += " IN ACCOUNT"
+ } else {
+ query += " IN DATABASE " +
quoteTblName(*catalog)
+ }
+ case objViews, objTables, objObjects:
+ query = addLike(query, tableName)
+
+ if catalog == nil || isWildcardStr(*catalog) {
+ query += " IN ACCOUNT"
+ } else {
+ escapedCatalog := quoteTblName(*catalog)
+ if dbSchema == nil || isWildcardStr(*dbSchema) {
+ query += " IN DATABASE " + escapedCatalog
+ } else {
+ query += " IN SCHEMA " +
escapedCatalog + "." + quoteTblName(*dbSchema)
+ }
+ }
+ default:
+ return fmt.Errorf("unimplemented object type")
+ }
+
+ *outQueryID, err = getQueryID(ctx, query, driverConn)
+ return
+ })
+ })
+}
+
+func isWildcardStr(ident string) bool {
+ return strings.ContainsAny(ident, "_%")
+}
+
+func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog, dbSchema, tableName, columnName *string, tableType []string) (array.RecordReader, error) {
var (
pkQueryID, fkQueryID, uniqueQueryID, terseDbQueryID string
+ showSchemaQueryID, tableQueryID string
)
conn, err := c.sqldb.Conn(ctx)
@@ -84,81 +181,117 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
}
defer conn.Close()
+ var hasViews, hasTables bool
+ for _, t := range tableType {
+ if strings.EqualFold("VIEW", t) {
+ hasViews = true
+ } else if strings.EqualFold("TABLE", t) {
+ hasTables = true
+ }
+ }
+
+ // force empty result from SHOW TABLES if tableType list is not empty
+ // and does not contain TABLE or VIEW in the list.
+ // we need this because we should have non-null db_schema_tables when
+ // depth is Tables, Columns or All.
+ var badTableType = "tabletypedoesnotexist"
+ if len(tableType) > 0 && depth >= adbc.ObjectDepthTables && !hasViews && !hasTables {
+ tableName = &badTableType
+ tableType = []string{"TABLE"}
+ }
+
gQueryIDs, gQueryIDsCtx := errgroup.WithContext(ctx)
queryFile := queryTemplateGetObjectsAll
switch depth {
case adbc.ObjectDepthCatalogs:
- if catalog == nil {
- queryFile = queryTemplateGetObjectsTerseCatalogs
- // if the catalog is null, show the terse databases
- // which doesn't require a database context
- gQueryIDs.Go(func() error {
- return conn.Raw(func(driverConn any) error {
- rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW TERSE DATABASES", nil)
- if err != nil {
- return err
- }
-
- terseDbQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
- return rows.Close()
- })
- })
- } else {
- queryFile = queryTemplateGetObjectsCatalogs
- }
+ queryFile = queryTemplateGetObjectsTerseCatalogs
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
+ catalog, dbSchema, tableName, &terseDbQueryID)
case adbc.ObjectDepthDBSchemas:
queryFile = queryTemplateGetObjectsDbSchemas
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
+ catalog, dbSchema, tableName, &showSchemaQueryID)
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
+ catalog, dbSchema, tableName, &terseDbQueryID)
case adbc.ObjectDepthTables:
queryFile = queryTemplateGetObjectsTables
- fallthrough
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
+ catalog, dbSchema, tableName, &showSchemaQueryID)
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
+ catalog, dbSchema, tableName, &terseDbQueryID)
+
+ objType := objObjects
+ if len(tableType) == 1 {
+ if strings.EqualFold("VIEW", tableType[0]) {
+ objType = objViews
+ } else if strings.EqualFold("TABLE", tableType[0]) {
+ objType = objTables
+ }
+ }
+
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType,
+ catalog, dbSchema, tableName, &tableQueryID)
default:
+ var suffix string
+ if catalog == nil || isWildcardStr(*catalog) {
+ suffix = " IN ACCOUNT"
+ } else {
+ escapedCatalog := quoteTblName(*catalog)
+ if dbSchema == nil || isWildcardStr(*dbSchema) {
+ suffix = " IN DATABASE " + escapedCatalog
+ } else {
+ escapedSchema := quoteTblName(*dbSchema)
+ if tableName == nil || isWildcardStr(*tableName) {
+ suffix = " IN SCHEMA " + escapedCatalog + "." + escapedSchema
+ } else {
+ escapedTable := quoteTblName(*tableName)
+ suffix = " IN TABLE " + escapedCatalog + "." + escapedSchema + "." + escapedTable
+ }
+ }
+ }
+
// Detailed constraint info not available in information_schema
// Need to dispatch SHOW queries and use conn.Raw to extract the queryID for reuse in GetObjects query
gQueryIDs.Go(func() error {
- return conn.Raw(func(driverConn any) error {
- rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW PRIMARY KEYS", nil)
- if err != nil {
- return err
- }
-
- pkQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
- return rows.Close()
+ return conn.Raw(func(driverConn any) (err error) {
+ pkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW PRIMARY KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
+ return err
})
})
gQueryIDs.Go(func() error {
- return conn.Raw(func(driverConn any) error {
- rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW IMPORTED KEYS", nil)
- if err != nil {
- return err
- }
-
- fkQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
- return rows.Close()
+ return conn.Raw(func(driverConn any) (err error) {
+ fkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW IMPORTED KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
+ return err
})
})
gQueryIDs.Go(func() error {
- return conn.Raw(func(driverConn any) error {
- rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW UNIQUE KEYS", nil)
- if err != nil {
- return err
- }
-
- uniqueQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
- return rows.Close()
+ return conn.Raw(func(driverConn any) (err error) {
+ uniqueQueryID, err = getQueryID(gQueryIDsCtx, "SHOW UNIQUE KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
+ return err
})
})
- }
- f, err := queryTemplates.Open(path.Join("queries", queryFile))
- if err != nil {
- return nil, err
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
+ catalog, dbSchema, tableName, &terseDbQueryID)
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
+ catalog, dbSchema, tableName, &showSchemaQueryID)
+
+ objType := objObjects
+ if len(tableType) == 1 {
+ if strings.EqualFold("VIEW", tableType[0]) {
+ objType = objViews
+ } else if strings.EqualFold("TABLE", tableType[0]) {
+ objType = objTables
+ }
+ }
+ goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType,
+ catalog, dbSchema, tableName, &tableQueryID)
}
- defer f.Close()
- var bldr strings.Builder
- if _, err := io.Copy(&bldr, f); err != nil {
+ queryBytes, err := fs.ReadFile(queryTemplates, path.Join("queries", queryFile))
+ if err != nil {
return nil, err
}
@@ -180,80 +313,71 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
sql.Named("FK_QUERY_ID", fkQueryID),
sql.Named("UNIQUE_QUERY_ID", uniqueQueryID),
sql.Named("SHOW_DB_QUERY_ID", terseDbQueryID),
+ sql.Named("SHOW_SCHEMA_QUERY_ID", showSchemaQueryID),
+ sql.Named("SHOW_TABLE_QUERY_ID", tableQueryID),
}
- // the connection that is used is not the same connection context where the database may have been set
- // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate
- if !isNilOrEmpty(catalog) {
- _, e := conn.ExecContext(context.Background(), fmt.Sprintf("USE DATABASE %s;", quoteTblName(*catalog)), nil)
- if e != nil {
- return nil, errToAdbcErr(adbc.StatusIO, e)
+ // currently only the Columns / all case still requires a current database/schema
+ // to be propagated. The rest of the cases all solely use SHOW queries for the metadata
+ // just as done by the snowflake JDBC driver. In those cases we don't need to propagate
+ // the current session database/schema.
+ if depth == adbc.ObjectDepthColumns || depth == adbc.ObjectDepthAll {
+ dbname, err := c.GetCurrentCatalog()
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
}
- }
- // the connection that is used is not the same connection context where the schema may have been set
- // if the caller called SetCurrentDbSchema() so need to ensure the schema context is appropriate
- if !isNilOrEmpty(dbSchema) {
- _, e2 := conn.ExecContext(context.Background(), fmt.Sprintf("USE SCHEMA %s;", quoteTblName(*dbSchema)), nil)
- if e2 != nil {
- return nil, errToAdbcErr(adbc.StatusIO, e2)
+ schemaname, err := c.GetCurrentDbSchema()
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+
+ // the connection that is used is not the same connection context where the database may have been set
+ // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate
+ multiCtx, _ := gosnowflake.WithMultiStatement(ctx, 2)
+ _, err = conn.ExecContext(multiCtx, fmt.Sprintf("USE DATABASE %s; USE SCHEMA %s;", quoteTblName(dbname), quoteTblName(schemaname)))
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
}
}
- query := bldr.String()
+ query := string(queryBytes)
rows, err := conn.QueryContext(ctx, query, args...)
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
}
defer rows.Close()
- catalogCh := make(chan driverbase.GetObjectsInfo, 1)
- readerCh := make(chan array.RecordReader)
+ catalogCh := make(chan driverbase.GetObjectsInfo, runtime.NumCPU())
errCh := make(chan error)
go func() {
- rdr, err := driverbase.BuildGetObjectsRecordReader(c.Alloc, catalogCh)
- if err != nil {
- errCh <- err
- }
-
- readerCh <- rdr
- close(readerCh)
- }()
-
- for rows.Next() {
- var getObjectsCatalog driverbase.GetObjectsInfo
- if err := rows.Scan(&getObjectsCatalog); err != nil {
- return nil, errToAdbcErr(adbc.StatusInvalidData, err)
- }
+ defer close(catalogCh)
+ for rows.Next() {
+ var getObjectsCatalog driverbase.GetObjectsInfo
+ if err := rows.Scan(&getObjectsCatalog); err != nil {
+ errCh <- errToAdbcErr(adbc.StatusInvalidData, err)
+ return
+ }
- // A few columns need additional processing outside of Snowflake
- for i, sch := range getObjectsCatalog.CatalogDbSchemas {
- for j, tab := range sch.DbSchemaTables {
- for k, col := range tab.TableColumns {
- field := c.toArrowField(col)
- xdbcDataType := driverbase.ToXdbcDataType(field.Type)
+ // A few columns need additional processing outside of Snowflake
+ for i, sch := range getObjectsCatalog.CatalogDbSchemas {
+ for j, tab := range sch.DbSchemaTables {
+ for k, col := range tab.TableColumns {
+ field := c.toArrowField(col)
+ xdbcDataType := driverbase.ToXdbcDataType(field.Type)
- getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID()))
- getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = driverbase.Nullable(int16(xdbcDataType))
+ getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID()))
+ getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = driverbase.Nullable(int16(xdbcDataType))
+ }
}
}
- }
-
- catalogCh <- getObjectsCatalog
- }
- close(catalogCh)
- select {
- case rdr := <-readerCh:
- return rdr, nil
- case err := <-errCh:
- return nil, err
- }
-}
+ catalogCh <- getObjectsCatalog
+ }
+ }()
-func isNilOrEmpty(str *string) bool {
- return str == nil || *str == ""
+ return driverbase.BuildGetObjectsRecordReader(c.Alloc, catalogCh, errCh)
}
// PrepareDriverInfo implements driverbase.DriverInfoPreparer.
@@ -266,7 +390,7 @@ func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc
// ListTableTypes implements driverbase.TableTypeLister.
func (*connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) {
- return []string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil
+ return []string{"TABLE", "VIEW"}, nil
}
// GetCurrentCatalog implements driverbase.CurrentNamespacer.
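For reference, the query-ID reuse that getQueryID enables works like
this round trip: dispatch a SHOW command, capture its query ID from
the raw gosnowflake connection, then read the cached result back via
RESULT_SCAN instead of re-running the command. A hedged sketch (the
DSN is a placeholder and error handling is abbreviated):

    package main

    import (
        "context"
        "database/sql"
        "database/sql/driver"
        "fmt"
        "log"

        "github.com/snowflakedb/gosnowflake"
    )

    func main() {
        ctx := context.Background()
        db, err := sql.Open("snowflake", "user:pass@account/db/schema") // placeholder DSN
        if err != nil {
            log.Fatal(err)
        }
        conn, err := db.Conn(ctx)
        if err != nil {
            log.Fatal(err)
        }
        defer conn.Close()

        // Step 1: run the SHOW command on the raw driver connection and
        // capture its query ID, mirroring getQueryID in this commit.
        var queryID string
        err = conn.Raw(func(driverConn any) error {
            rows, err := driverConn.(driver.QueryerContext).QueryContext(ctx, "SHOW TERSE DATABASES", nil)
            if err != nil {
                return err
            }
            queryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
            return rows.Close()
        })
        if err != nil {
            log.Fatal(err)
        }

        // Step 2: reuse the cached result set by query ID rather than
        // re-executing the SHOW command.
        rows, err := conn.QueryContext(ctx, `SELECT "name" FROM TABLE(RESULT_SCAN(?))`, queryID)
        if err != nil {
            log.Fatal(err)
        }
        defer rows.Close()
        for rows.Next() {
            var name string
            if err := rows.Scan(&name); err != nil {
                log.Fatal(err)
            }
            fmt.Println(name)
        }
    }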
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index 895015ffd..c67389ca1 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -1215,15 +1215,15 @@ func (suite *SnowflakeTests) TestSqlIngestMapType() {
[
{
"col_int64": 1,
- "col_map": "{\n \"key_value\": [\n {\n
\"key\": \"key1\",\n \"value\": 1\n }\n ]\n}"
+ "col_map": "{\n \"key1\": 1\n}"
},
{
"col_int64": 2,
- "col_map": "{\n \"key_value\": [\n {\n
\"key\": \"key2\",\n \"value\": 2\n }\n ]\n}"
+ "col_map": "{\n \"key2\": 2\n}"
},
{
"col_int64": 3,
- "col_map": "{\n \"key_value\": [\n {\n
\"key\": \"key3\",\n \"value\": 3\n }\n ]\n}"
+ "col_map": "{\n \"key3\": 3\n}"
}
]
`)))
@@ -2161,6 +2161,9 @@ func (suite *SnowflakeTests) TestGetSetClientConfigFile() {
func (suite *SnowflakeTests) TestGetObjectsWithNilCatalog() {
// this test demonstrates calling GetObjects with the catalog depth and a nil catalog
- _, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthCatalogs, nil, nil, nil, nil, nil)
+ rdr, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthCatalogs, nil, nil, nil, nil, nil)
suite.NoError(err)
+ // test suite validates memory allocator so we need to make sure we call
+ // release on the result reader
+ rdr.Release()
}
diff --git a/go/adbc/driver/snowflake/queries/get_objects_all.sql b/go/adbc/driver/snowflake/queries/get_objects_all.sql
index 45b807f15..7fc10f2e2 100644
--- a/go/adbc/driver/snowflake/queries/get_objects_all.sql
+++ b/go/adbc/driver/snowflake/queries/get_objects_all.sql
@@ -86,12 +86,12 @@ constraints AS (
table_catalog,
table_schema,
table_name,
- ARRAY_AGG({
+ ARRAY_AGG(NULLIF({
'constraint_name': constraint_name,
'constraint_type': constraint_type,
'constraint_column_names': constraint_column_names,
'constraint_column_usage': constraint_column_usage
- }) table_constraints,
+ }, {})) table_constraints,
FROM (
SELECT * FROM pk_constraints
UNION ALL
@@ -105,12 +105,12 @@ tables AS (
SELECT
table_catalog catalog_name,
table_schema schema_name,
- ARRAY_AGG({
+ ARRAY_AGG(NULLIF({
'table_name': table_name,
'table_type': table_type,
- 'table_columns': table_columns,
- 'table_constraints': table_constraints
- }) db_schema_tables
+ 'table_columns': COALESCE(table_columns, []),
+ 'table_constraints': COALESCE(table_constraints, [])
+ }, {})) db_schema_tables
FROM information_schema.tables
LEFT JOIN columns
USING (table_catalog, table_schema, table_name)
@@ -123,7 +123,7 @@ db_schemas AS (
SELECT
catalog_name,
schema_name,
- db_schema_tables,
+ COALESCE(db_schema_tables, []) db_schema_tables,
FROM information_schema.schemata
LEFT JOIN tables
USING (catalog_name, schema_name)
@@ -132,10 +132,10 @@ db_schemas AS (
SELECT
{
'catalog_name': database_name,
- 'catalog_db_schemas': ARRAY_AGG({
+ 'catalog_db_schemas': ARRAY_AGG(NULLIF({
'db_schema_name': schema_name,
'db_schema_tables': db_schema_tables
- })
+ }, {}))
} get_objects
FROM
information_schema.databases
diff --git a/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql b/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql
deleted file mode 100644
index ec2cef515..000000000
--- a/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql
+++ /dev/null
@@ -1,25 +0,0 @@
--- Licensed to the Apache Software Foundation (ASF) under one
--- or more contributor license agreements. See the NOTICE file
--- distributed with this work for additional information
--- regarding copyright ownership. The ASF licenses this file
--- to you under the Apache License, Version 2.0 (the
--- "License"); you may not use this file except in compliance
--- with the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing,
--- software distributed under the License is distributed on an
--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
--- KIND, either express or implied. See the License for the
--- specific language governing permissions and limitations
--- under the License.
-
-SELECT
- {
- 'catalog_name': database_name,
- 'catalog_db_schemas': null
- } get_objects
-FROM
- information_schema.databases
-WHERE database_name ILIKE :CATALOG;
diff --git a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql
index 360a6d083..872118f7c 100644
--- a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql
+++ b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql
@@ -17,22 +17,26 @@
WITH db_schemas AS (
SELECT
- catalog_name,
- schema_name,
- FROM information_schema.schemata
- WHERE catalog_name ILIKE :CATALOG AND schema_name ILIKE :DB_SCHEMA
+ "database_name" as "catalog_name",
+ "name" as "schema_name"
+ FROM table(RESULT_SCAN(:SHOW_SCHEMA_QUERY_ID))
+ WHERE "database_name" ILIKE :CATALOG
+), db_info AS (
+ SELECT "name" AS "database_name"
+ FROM table(RESULT_SCAN(:SHOW_DB_QUERY_ID))
+ WHERE "name" ILIKE :CATALOG
)
SELECT
{
- 'catalog_name': database_name,
- 'catalog_db_schemas': ARRAY_AGG({
- 'db_schema_name': schema_name,
+ 'catalog_name': "database_name",
+ 'catalog_db_schemas': ARRAY_AGG(NULLIF({
+ 'db_schema_name': "schema_name",
'db_schema_tables': null
- })
+ }, {}))
} get_objects
FROM
- information_schema.databases
+ db_info
LEFT JOIN db_schemas
-ON database_name = catalog_name
-WHERE database_name ILIKE :CATALOG
-GROUP BY database_name;
+ON "database_name" = "catalog_name"
+WHERE "database_name" ILIKE :CATALOG
+GROUP BY "database_name";
diff --git a/go/adbc/driver/snowflake/queries/get_objects_tables.sql b/go/adbc/driver/snowflake/queries/get_objects_tables.sql
index b3b16ff51..9d6ce36ed 100644
--- a/go/adbc/driver/snowflake/queries/get_objects_tables.sql
+++ b/go/adbc/driver/snowflake/queries/get_objects_tables.sql
@@ -15,107 +15,41 @@
-- specific language governing permissions and limitations
-- under the License.
-WITH pk_constraints AS (
- SELECT
- "database_name" table_catalog,
- "schema_name" table_schema,
- "table_name" table_name,
- "constraint_name" constraint_name,
- 'PRIMARY KEY' constraint_type,
- ARRAY_AGG("column_name") WITHIN GROUP (ORDER BY "key_sequence")
constraint_column_names,
- [] constraint_column_usage,
- FROM TABLE(RESULT_SCAN(:PK_QUERY_ID))
- WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE
- GROUP BY table_catalog, table_schema, table_name, "constraint_name"
-),
-unique_constraints AS (
- SELECT
- "database_name" table_catalog,
- "schema_name" table_schema,
- "table_name" table_name,
- "constraint_name" constraint_name,
- 'UNIQUE' constraint_type,
- ARRAY_AGG("column_name") WITHIN GROUP (ORDER BY "key_sequence")
constraint_column_names,
- [] constraint_column_usage,
- FROM TABLE(RESULT_SCAN(:UNIQUE_QUERY_ID))
- WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE
- GROUP BY table_catalog, table_schema, table_name, "constraint_name"
-),
-fk_constraints AS (
- SELECT
- "fk_database_name" table_catalog,
- "fk_schema_name" table_schema,
- "fk_table_name" table_name,
- "fk_name" constraint_name,
- 'FOREIGN KEY' constraint_type,
- ARRAY_AGG("fk_column_name") WITHIN GROUP (ORDER BY "key_sequence")
constraint_column_names,
- ARRAY_AGG({
- 'fk_catalog': "pk_database_name",
- 'fk_db_schema': "pk_schema_name",
- 'fk_table': "pk_table_name",
- 'fk_column_name': "pk_column_name"
- }) WITHIN GROUP (ORDER BY "key_sequence") constraint_column_usage,
- FROM TABLE(RESULT_SCAN(:FK_QUERY_ID))
- WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE
- GROUP BY table_catalog, table_schema, table_name, constraint_name
-),
-constraints AS (
- SELECT
- table_catalog,
- table_schema,
- table_name,
- ARRAY_AGG({
- 'constraint_name': constraint_name,
- 'constraint_type': constraint_type,
- 'constraint_column_names': constraint_column_names,
- 'constraint_column_usage': constraint_column_usage
- }) table_constraints,
- FROM (
- SELECT * FROM pk_constraints
- UNION ALL
- SELECT * FROM unique_constraints
- UNION ALL
- SELECT * FROM fk_constraints
- )
- GROUP BY table_catalog, table_schema, table_name
-),
-tables AS (
+WITH tables AS (
SELECT
- table_catalog catalog_name,
- table_schema schema_name,
+ "database_name" "catalog_name",
+ "schema_name" "schema_name",
ARRAY_AGG({
- 'table_name': table_name,
- 'table_type': table_type,
- 'table_constraints': table_constraints,
+ 'table_name': "name",
+ 'table_type': "kind",
+ 'table_constraints': null,
'table_columns': null
}) db_schema_tables
-FROM information_schema.tables
-LEFT JOIN constraints
-USING (table_catalog, table_schema, table_name)
-WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE
-GROUP BY table_catalog, table_schema
+FROM TABLE(RESULT_SCAN(:SHOW_TABLE_QUERY_ID))
+WHERE "database_name" ILIKE :CATALOG AND "schema_name" ILIKE :DB_SCHEMA AND
"name" ILIKE :TABLE
+GROUP BY "database_name", "schema_name"
),
db_schemas AS (
SELECT
- catalog_name,
- schema_name,
- db_schema_tables,
- FROM information_schema.schemata
+ "database_name" "catalog_name",
+ "name" "schema_name",
+ COALESCE(db_schema_tables, []) db_schema_tables,
+ FROM TABLE(RESULT_SCAN(:SHOW_SCHEMA_QUERY_ID))
LEFT JOIN tables
- USING (catalog_name, schema_name)
- WHERE catalog_name ILIKE :CATALOG AND schema_name ILIKE :DB_SCHEMA
+ ON "database_name" = "catalog_name" AND "name" = tables."schema_name"
+ WHERE "database_name" ILIKE :CATALOG AND "name" ILIKE :DB_SCHEMA
)
SELECT
{
- 'catalog_name': database_name,
- 'catalog_db_schemas': ARRAY_AGG({
- 'db_schema_name': schema_name,
+ 'catalog_name': "name",
+ 'catalog_db_schemas': ARRAY_AGG(NULLIF({
+ 'db_schema_name': db_schemas."schema_name",
'db_schema_tables': db_schema_tables
- })
+ }, {}))
} get_objects
FROM
- information_schema.databases
+ TABLE(RESULT_SCAN(:SHOW_DB_QUERY_ID))
LEFT JOIN db_schemas
-ON database_name = catalog_name
-WHERE database_name ILIKE :CATALOG
-GROUP BY database_name;
+ON "name" = "catalog_name"
+WHERE "name" ILIKE :CATALOG
+GROUP BY "name";
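On the Go side, those :CATALOG / :DB_SCHEMA / :TABLE filters and the
captured query IDs reach the template through sql.Named arguments, as
in the args slice shown in the connection.go hunk above. A trivial
illustration of how such an argument list is assembled (the query IDs
here are placeholders; the driver also binds PK_QUERY_ID, FK_QUERY_ID
and UNIQUE_QUERY_ID for the constraint queries):

    package main

    import (
        "database/sql"
        "fmt"
    )

    func main() {
        // Named binds matching the placeholders in the rewritten templates.
        args := []any{
            sql.Named("CATALOG", "%"),
            sql.Named("DB_SCHEMA", "%"),
            sql.Named("TABLE", "%"),
            sql.Named("SHOW_DB_QUERY_ID", "placeholder-id-0"),
            sql.Named("SHOW_SCHEMA_QUERY_ID", "placeholder-id-1"),
            sql.Named("SHOW_TABLE_QUERY_ID", "placeholder-id-2"),
        }
        for _, a := range args {
            fmt.Printf("%+v\n", a)
        }
        // In the driver: rows, err := conn.QueryContext(ctx, query, args...)
    }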
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index 1fd1f658f..574e39045 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -321,9 +321,9 @@ func toSnowflakeType(dt arrow.DataType) string {
case arrow.DECIMAL, arrow.DECIMAL256:
dec := dt.(arrow.DecimalType)
return fmt.Sprintf("NUMERIC(%d,%d)", dec.GetPrecision(), dec.GetScale())
- case arrow.STRING, arrow.LARGE_STRING:
+ case arrow.STRING, arrow.LARGE_STRING, arrow.STRING_VIEW:
return "text"
- case arrow.BINARY, arrow.LARGE_BINARY:
+ case arrow.BINARY, arrow.LARGE_BINARY, arrow.BINARY_VIEW:
return "binary"
case arrow.FIXED_SIZE_BINARY:
fsb := dt.(*arrow.FixedSizeBinaryType)