This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 73f5aba7a feat(go/adbc/driver/snowflake): added table constraints implementation for GetObjects API (#1593)
73f5aba7a is described below
commit 73f5aba7af97b435117fa85d8325787b2370ba9a
Author: Ryan Syed <[email protected]>
AuthorDate: Wed Mar 27 10:55:56 2024 -0700
feat(go/adbc/driver/snowflake): added table constraints implementation for GetObjects API (#1593)
## Issue
The table constraints implementation was missing for GetObjects.
## Fix
* Added the table constraints implementation, which is returned for an ObjectsDepth of Tables and Columns (a usage sketch follows this list).
* Added tests in the interop layer.
* Modified the existing tests in `connection_test.go` for the generated SQL statements.
* Performance is slightly impacted by the addition of the table constraints.
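For illustration only (not part of this change), here is a minimal sketch of how a caller might request the new constraint metadata through the Go ADBC API; the connection is assumed to be an already-open Snowflake connection, and `catalog`/`schema` are hypothetical names:

```go
package example

import (
	"context"
	"fmt"

	"github.com/apache/arrow-adbc/go/adbc"
)

// printObjects asks for the full metadata tree. With this change, each table in
// the returned catalog_db_schemas hierarchy carries a populated
// table_constraints list (constraint_name, constraint_type,
// constraint_column_names, constraint_column_usage).
func printObjects(ctx context.Context, cnxn adbc.Connection, catalog, schema string) error {
	rdr, err := cnxn.GetObjects(ctx, adbc.ObjectDepthAll, &catalog, &schema, nil, nil, nil)
	if err != nil {
		return err
	}
	defer rdr.Release()

	for rdr.Next() {
		rec := rdr.Record()
		// Records follow the standard GetObjects schema:
		// catalog_name | catalog_db_schemas (nested db schemas, tables, columns, constraints).
		fmt.Printf("record with %d catalog row(s): %v\n", rec.NumRows(), rec.Schema())
	}
	return nil
}
```

The nested constraint values can then be unpacked with the Arrow list/struct accessors, which is what the C# interop tests added in this commit do via `GetObjectsParser`.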
## Design

## Performance
After initial changes:

After additional changes to improve performance:

Before:

---------
Co-authored-by: David Li <[email protected]>
---
.../Metadata/AdbcUsageSchema.cs | 21 ++
.../Metadata/GetObjectsParser.cs | 103 ++++--
...row.Adbc.Tests.Drivers.Interop.Snowflake.csproj | 17 +-
.../Drivers/Interop/Snowflake/ConstraintTests.cs | 247 +++++++++++++
.../Snowflake/Resources/SnowflakeConstraints.sql | 39 +++
.../Interop/Snowflake/SnowflakeTestingUtils.cs | 33 +-
go/adbc/driver/internal/shared_utils.go | 131 ++++++-
go/adbc/driver/snowflake/connection.go | 388 +++++++++++++++++++--
go/adbc/driver/snowflake/connection_test.go | 169 +++++----
9 files changed, 991 insertions(+), 157 deletions(-)
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/AdbcUsageSchema.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/AdbcUsageSchema.cs
index c996a59b2..cdf9804a0 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/AdbcUsageSchema.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/AdbcUsageSchema.cs
@@ -26,5 +26,26 @@ namespace Apache.Arrow.Adbc.Tests.Metadata
public string FkTable { get; set; }
public string FkColumnName { get; set; }
+
+ public override bool Equals(object obj)
+ {
+ if (obj == null || obj.GetType() != this.GetType())
+ {
+ return false;
+ }
+
+ var other = (AdbcUsageSchema)obj;
+ return this.FkCatalog == other.FkCatalog && this.FkDbSchema ==
other.FkDbSchema && this.FkTable == other.FkTable && this.FkColumnName ==
other.FkColumnName;
+ }
+
+ public override int GetHashCode()
+ {
+ int hash = 17;
+ hash = hash * 31 + (FkCatalog != null ? FkCatalog.GetHashCode() :
0);
+ hash = hash * 31 + (FkDbSchema != null ? FkDbSchema.GetHashCode()
: 0);
+ hash = hash * 31 + (FkTable != null ? FkTable.GetHashCode() : 0);
+ hash = hash * 31 + (FkColumnName != null ?
FkColumnName.GetHashCode() : 0);
+ return hash;
+ }
}
}
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/GetObjectsParser.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/GetObjectsParser.cs
index 44af9bb0b..ef0d4fde4 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/GetObjectsParser.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/GetObjectsParser.cs
@@ -15,9 +15,8 @@
* limitations under the License.
*/
+using System;
using System.Collections.Generic;
-using System.Linq;
-using System.Text.Unicode;
namespace Apache.Arrow.Adbc.Tests.Metadata
{
@@ -104,25 +103,25 @@ namespace Apache.Arrow.Adbc.Tests.Metadata
List<AdbcColumn> columns = new List<AdbcColumn>();
- StringArray column_name =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "column_name")]; // column_name | utf8 not null
- Int32Array ordinal_position =
(Int32Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "ordinal_position")]; // ordinal_position | int32
- StringArray remarks =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "remarks")]; // remarks | utf8
- Int16Array xdbc_data_type =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_data_type")]; // xdbc_data_type | int16
- StringArray xdbc_type_name =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_type_name")]; // xdbc_type_name | utf8
- Int32Array xdbc_column_size =
(Int32Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_column_size")]; // xdbc_column_size | int32
- Int16Array xdbc_decimal_digits =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_decimal_digits")]; // xdbc_decimal_digits | int16
- Int16Array xdbc_num_prec_radix =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_num_prec_radix")];// xdbc_num_prec_radix | int16
- Int16Array xdbc_nullable =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_nullable")];// xdbc_nullable | int16
- StringArray xdbc_column_def =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_column_def")]; // xdbc_column_def | utf8
- Int16Array xdbc_sql_data_type =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_sql_data_type")];// xdbc_sql_data_type | int16
- Int16Array xdbc_datetime_sub =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_datetime_sub")]; // xdbc_datetime_sub | int16
- Int32Array xdbc_char_octet_length =
(Int32Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_char_octet_length")]; // xdbc_char_octet_length |
int32
- StringArray xdbc_is_nullable =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_is_nullable")]; // xdbc_is_nullable | utf8
- StringArray xdbc_scope_catalog =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_scope_catalog")];// xdbc_scope_catalog | utf8
- StringArray xdbc_scope_schema =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_scope_schema")]; // xdbc_scope_schema | utf8
- StringArray xdbc_scope_table =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_scope_table")]; // xdbc_scope_table | utf8
- BooleanArray xdbc_is_autoincrement =
(BooleanArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_is_autoincrement")]; // xdbc_is_autoincrement | bool
- BooleanArray xdbc_is_generatedcolumn =
(BooleanArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndex(f =>
f.Name == "xdbc_is_generatedcolumn")]; // xdbc_is_generatedcolumn |
bool
+ StringArray column_name =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("column_name")];
// column_name | utf8 not null
+ Int32Array ordinal_position =
(Int32Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("ordinal_position")];
// ordinal_position | int32
+ StringArray remarks =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("remarks")];
// remarks | utf8
+ Int16Array xdbc_data_type =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_data_type")];
// xdbc_data_type | int16
+ StringArray xdbc_type_name =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_type_name")];
// xdbc_type_name | utf8
+ Int32Array xdbc_column_size =
(Int32Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_column_size")];
// xdbc_column_size | int32
+ Int16Array xdbc_decimal_digits =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_decimal_digits")];
// xdbc_decimal_digits | int16
+ Int16Array xdbc_num_prec_radix =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_num_prec_radix")];//
xdbc_num_prec_radix | int16
+ Int16Array xdbc_nullable =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_nullable")];//
xdbc_nullable | int16
+ StringArray xdbc_column_def =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_column_def")];
// xdbc_column_def | utf8
+ Int16Array xdbc_sql_data_type =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_sql_data_type")];//
xdbc_sql_data_type | int16
+ Int16Array xdbc_datetime_sub =
(Int16Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_datetime_sub")];
// xdbc_datetime_sub | int16
+ Int32Array xdbc_char_octet_length =
(Int32Array)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_char_octet_length")];
// xdbc_char_octet_length | int32
+ StringArray xdbc_is_nullable =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_is_nullable")];
// xdbc_is_nullable | utf8
+ StringArray xdbc_scope_catalog =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_scope_catalog")];//
xdbc_scope_catalog | utf8
+ StringArray xdbc_scope_schema =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_scope_schema")];
// xdbc_scope_schema | utf8
+ StringArray xdbc_scope_table =
(StringArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_scope_table")];
// xdbc_scope_table | utf8
+ BooleanArray xdbc_is_autoincrement =
(BooleanArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_is_autoincrement")];
// xdbc_is_autoincrement | bool
+ BooleanArray xdbc_is_generatedcolumn =
(BooleanArray)columnsArray.Fields[StandardSchemas.ColumnSchema.FindIndexOrThrow("xdbc_is_generatedcolumn")];
// xdbc_is_generatedcolumn | bool
for (int i = 0; i < columnsArray.Length; i++)
{
@@ -159,10 +158,10 @@ namespace Apache.Arrow.Adbc.Tests.Metadata
List<AdbcConstraint> constraints = new List<AdbcConstraint>();
- StringArray name =
(StringArray)constraintsArray.Fields[StandardSchemas.ConstraintSchema.FindIndex(f
=> f.Name == "constraint_name")]; // constraint_name | utf8
- StringArray type =
(StringArray)constraintsArray.Fields[StandardSchemas.ConstraintSchema.FindIndex(f
=> f.Name == "constraint_type")]; // constraint_type | utf8 not null
- ListArray column_names =
(ListArray)constraintsArray.Fields[StandardSchemas.ConstraintSchema.FindIndex(f
=> f.Name == "constraint_column_names")]; // constraint_column_names |
list<utf8> not null
- ListArray column_usage =
(ListArray)constraintsArray.Fields[StandardSchemas.ConstraintSchema.FindIndex(f
=> f.Name == "constraint_column_usage")]; // constraint_column_usage |
list<USAGE_SCHEMA>
+ StringArray name =
(StringArray)constraintsArray.Fields[StandardSchemas.ConstraintSchema.FindIndexOrThrow("constraint_name")];
// constraint_name | utf8
+ StringArray type =
(StringArray)constraintsArray.Fields[StandardSchemas.ConstraintSchema.FindIndexOrThrow("constraint_type")];
// constraint_type | utf8 not null
+ ListArray columnNames =
(ListArray)constraintsArray.Fields[StandardSchemas.ConstraintSchema.FindIndexOrThrow("constraint_column_names")];
// constraint_column_names | list<utf8> not null
+ ListArray columnUsages =
(ListArray)constraintsArray.Fields[StandardSchemas.ConstraintSchema.FindIndexOrThrow("constraint_column_usage")];
// constraint_column_usage | list<USAGE_SCHEMA>
for (int i = 0; i < constraintsArray.Length; i++)
{
@@ -170,17 +169,26 @@ namespace Apache.Arrow.Adbc.Tests.Metadata
c.Name = name.GetString(i);
c.Type = type.GetString(i);
- StringArray col_names = column_names.GetSlicedValues(i) as
StringArray;
- StructArray usage = column_usage.GetSlicedValues(i) as
StructArray;
+ StringArray colNames = columnNames.GetSlicedValues(i) as
StringArray;
+ StructArray usages = columnUsages.GetSlicedValues(i) as
StructArray;
- if (usage != null)
+ if (colNames != null)
{
- for (int j = 0; j < usage.Length; j++)
+ for (int j = 0; j < colNames.Length; j++)
+ {
+ c.ColumnNames.Add(colNames.GetString(j));
+ }
+ }
+
+ if (usages != null)
+ {
+ StringArray fkCatalog =
(StringArray)usages.Fields[StandardSchemas.UsageSchema.FindIndexOrThrow("fk_catalog")];
// fk_catalog | utf8
+ StringArray fkDbSchema =
(StringArray)usages.Fields[StandardSchemas.UsageSchema.FindIndexOrThrow("fk_db_schema")];
//fk_db_schema | utf8
+ StringArray fkTable =
(StringArray)usages.Fields[StandardSchemas.UsageSchema.FindIndexOrThrow("fk_table")];
// fk_table | utf8 not null
+ StringArray fkColumnName =
(StringArray)usages.Fields[StandardSchemas.UsageSchema.FindIndexOrThrow("fk_column_name")];
// fk_column_name | utf8 not null
+
+ for (int j = 0; j < usages.Length; j++)
{
- StringArray fkCatalog =
(StringArray)usage.Fields[StandardSchemas.UsageSchema.FindIndex(f => f.Name ==
"fk_catalog")]; // fk_catalog | utf8
- StringArray fkDbSchema =
(StringArray)usage.Fields[StandardSchemas.UsageSchema.FindIndex(f => f.Name ==
"fk_db_schema")]; //fk_db_schema | utf8
- StringArray fkTable =
(StringArray)usage.Fields[StandardSchemas.UsageSchema.FindIndex(f => f.Name ==
"fk_table")]; // fk_table | utf8 not null
- StringArray fkColumnName =
(StringArray)usage.Fields[StandardSchemas.UsageSchema.FindIndex(f => f.Name ==
"fk_column_name")]; // fk_column_name | utf8 not null
AdbcUsageSchema adbcUsageSchema = new
AdbcUsageSchema();
adbcUsageSchema.FkCatalog = fkCatalog.GetString(j);
@@ -197,4 +205,33 @@ namespace Apache.Arrow.Adbc.Tests.Metadata
return constraints;
}
}
+
+ /// <summary>
+ /// Extension methods for List<Field> type
+ /// </summary>
+ ///
+ public static class FieldExtensions
+ {
+ /// <summary>
+ /// Finds the index of the first field with the provided name in the
list or throws an exception
+ /// </summary>
+ /// <param name="fields">The list of fields</param>
+ /// <param name="name">The field name to look for</param>
+ /// <returns>The index of the first field with the provided
name</returns>
+ /// <exception cref="ArgumentNullException">Thrown if fields argument
is null</exception>
+ /// <exception cref="InvalidOperationException">Thrown if no matching
field is found with the provided name</exception>
+ public static int FindIndexOrThrow(this List<Field> fields, string
name)
+ {
+ if (fields == null)
+ {
+ throw new ArgumentNullException(nameof(fields));
+ }
+ int index = fields.FindIndex(f => f.Name == name);
+ if (index == -1)
+ {
+ throw new InvalidOperationException($"No matching field found
with name: {name}");
+ }
+ return index;
+ }
+ }
}
diff --git a/csharp/test/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj b/csharp/test/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj
index 1dab15795..234729a01 100644
--- a/csharp/test/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj
+++ b/csharp/test/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj
@@ -1,7 +1,19 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>net472;net6.0</TargetFrameworks>
</PropertyGroup>
+ <ItemGroup>
+ <None Remove="Resources\SnowflakeConstraints.sql" />
+ <None Remove="Resources\SnowflakeData.sql" />
+ </ItemGroup>
+ <ItemGroup>
+ <EmbeddedResource Include="Resources\SnowflakeConstraints.sql">
+ <CopyToOutputDirectory>Never</CopyToOutputDirectory>
+ </EmbeddedResource>
+ <EmbeddedResource Include="Resources\SnowflakeData.sql">
+ <CopyToOutputDirectory>Never</CopyToOutputDirectory>
+ </EmbeddedResource>
+ </ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
<PackageReference Include="xunit" Version="2.7.0" />
@@ -17,9 +29,6 @@
<ProjectReference
Include="..\..\..\Apache.Arrow.Adbc.Tests\Apache.Arrow.Adbc.Tests.csproj" />
</ItemGroup>
<ItemGroup>
- <None Update="Resources\SnowflakeData.sql">
- <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
- </None>
<None Update="Resources\snowflakeconfig.json">
<CopyToOutputDirectory>Never</CopyToOutputDirectory>
</None>
diff --git a/csharp/test/Drivers/Interop/Snowflake/ConstraintTests.cs b/csharp/test/Drivers/Interop/Snowflake/ConstraintTests.cs
new file mode 100644
index 000000000..be0d325f8
--- /dev/null
+++ b/csharp/test/Drivers/Interop/Snowflake/ConstraintTests.cs
@@ -0,0 +1,247 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Apache.Arrow.Adbc.Tests.Metadata;
+using Apache.Arrow.Adbc.Tests.Xunit;
+using Apache.Arrow.Ipc;
+using Xunit;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
+{
+ [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer",
"Apache.Arrow.Adbc.Tests")]
+ public class ConstraintTests : IClassFixture<ConstraintTestsFixutre>
+ {
+ readonly SnowflakeTestConfiguration _snowflakeTestConfiguration;
+ readonly ConstraintTestsFixutre _fixture;
+ const string PRIMARY_KEY = "PRIMARY KEY";
+ const string UNIQUE = "UNIQUE";
+ const string FOREIGN_KEY = "FOREIGN KEY";
+
+ public ConstraintTests(ConstraintTestsFixutre fixture)
+ {
+
Skip.IfNot(Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE));
+ _snowflakeTestConfiguration =
SnowflakeTestingUtils.TestConfiguration;
+ _fixture = fixture;
+ }
+
+ /// <summary>
+ /// Validates if the driver can call GetObjects with GetObjectsDepth
as All and Column name as a pattern and get the table constraints data
+ /// </summary>
+ [SkippableTheory, Order(2)]
+ [InlineData(PRIMARY_KEY, "SYS_CONSTRAINT_", new string[] { "ID",
"NAME" })]
+ public void CanGetObjectsTableConstraintsWithColumnNameFilter(string
constraintType, string constraintNameStart, string[] columnNames)
+ {
+ // need to add the database
+ string databaseName = _snowflakeTestConfiguration.Metadata.Catalog;
+ string schemaName = _snowflakeTestConfiguration.Metadata.Schema;
+
+ using IArrowArrayStream stream = _fixture._connection.GetObjects(
+ depth: AdbcConnection.GetObjectsDepth.All,
+ catalogPattern: databaseName,
+ dbSchemaPattern: schemaName,
+ tableNamePattern: _fixture._tableName1,
+ tableTypes: _fixture._tableTypes,
+ columnNamePattern: columnNames[0]);
+
+ using RecordBatch recordBatch =
stream.ReadNextRecordBatchAsync().Result;
+
+ List<AdbcCatalog> catalogs =
GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);
+
+ List<AdbcTable> tables = catalogs
+ .Where(c => string.Equals(c.Name, databaseName))
+ .Select(c => c.DbSchemas)
+ .FirstOrDefault()
+ .Where(s => string.Equals(s.Name, schemaName))
+ .Select(s => s.Tables)
+ .FirstOrDefault();
+
+ AdbcTable table = tables.Where((table) =>
string.Equals(table.Name, _fixture._tableName1,
StringComparison.OrdinalIgnoreCase)).FirstOrDefault();
+ Assert.True(table != null, "table should not be null");
+ Assert.True(table.Constraints != null, "table constraints should
not be null");
+
+
+ AdbcConstraint adbcConstraint =
table.Constraints.Where((constraint) => string.Equals(constraint.Type,
constraintType)).FirstOrDefault();
+ Assert.True(adbcConstraint != null, $"{constraintType} should be
present");
+ Assert.StartsWith(constraintNameStart, adbcConstraint.Name);
+ Assert.True(adbcConstraint.ColumnNames.Count ==
columnNames.Length, "constraint column count doesn't match");
+ Assert.True(adbcConstraint.ColumnUsage.Count == 0, "ColumnUsages
is only for foreign key");
+ }
+
+ /// <summary>
+ /// Validates if the driver can call GetObjects with GetObjectsDepth
as Tables with TableName as a pattern and get the table constraints data
+ /// </summary>
+ [SkippableTheory, Order(3)]
+ [InlineData(UNIQUE, "SYS_CONSTRAINT_", new string[] { "ID" }, new
string[] { }, AdbcConnection.GetObjectsDepth.All)]
+ [InlineData(UNIQUE, "SYS_CONSTRAINT_", new string[] { "ID" }, new
string[] { }, AdbcConnection.GetObjectsDepth.Tables)]
+ [InlineData(PRIMARY_KEY, "SYS_CONSTRAINT_", new string[] {
"COMPANY_NAME" }, new string[] { }, AdbcConnection.GetObjectsDepth.All)]
+ [InlineData(PRIMARY_KEY, "SYS_CONSTRAINT_", new string[] {
"COMPANY_NAME" }, new string[] { }, AdbcConnection.GetObjectsDepth.Tables)]
+ [InlineData(FOREIGN_KEY, "ADBC_FKEY", new string[] { "DATABASE_ID",
"DATABASE_NAME" }, new string[] { "ID", "NAME" },
AdbcConnection.GetObjectsDepth.All)]
+ [InlineData(FOREIGN_KEY, "ADBC_FKEY", new string[] { "DATABASE_ID",
"DATABASE_NAME" }, new string[] { "ID", "NAME" },
AdbcConnection.GetObjectsDepth.Tables)]
+ public void CanGetObjectsTableConstraints(string constraintType,
string constraintNameStart, string[] columnNames, string[]
referenceColumnNames, AdbcConnection.GetObjectsDepth depth)
+ {
+ // need to add the database
+ string databaseName = _snowflakeTestConfiguration.Metadata.Catalog;
+ string schemaName = _snowflakeTestConfiguration.Metadata.Schema;
+
+ using IArrowArrayStream stream = _fixture._connection.GetObjects(
+ depth: depth,
+ catalogPattern: databaseName,
+ dbSchemaPattern: schemaName,
+ tableNamePattern: _fixture._tableName2,
+ tableTypes: _fixture._tableTypes,
+ columnNamePattern: null);
+
+ using RecordBatch recordBatch =
stream.ReadNextRecordBatchAsync().Result;
+
+ List<AdbcCatalog> catalogs =
GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);
+
+ List<AdbcTable> tables = catalogs
+ .Where(c => string.Equals(c.Name, databaseName))
+ .Select(c => c.DbSchemas)
+ .FirstOrDefault()
+ .Where(s => string.Equals(s.Name, schemaName))
+ .Select(s => s.Tables)
+ .FirstOrDefault();
+
+ AdbcTable table = tables.Where((table) =>
string.Equals(table.Name, _fixture._tableName2,
StringComparison.OrdinalIgnoreCase)).FirstOrDefault();
+ Assert.True(table != null, "table should not be null");
+ Assert.True(table.Constraints != null, "table constraints should
not be null");
+
+
+ AdbcConstraint adbcConstraint =
table.Constraints.Where((constraint) => string.Equals(constraint.Type,
constraintType)).FirstOrDefault();
+ Assert.True(adbcConstraint != null, $"{constraintType} should be
present");
+ Assert.StartsWith(constraintNameStart, adbcConstraint.Name);
+ Assert.True(adbcConstraint.ColumnNames.Count ==
columnNames.Length, "constraint column count doesn't match");
+ foreach (string columnName in columnNames)
+ {
+ Assert.True(adbcConstraint.ColumnNames.Where((col) =>
string.Equals(col, columnName)).FirstOrDefault() != null, $"{columnName} is not
marked as {constraintType}");
+ }
+ if (constraintType == FOREIGN_KEY)
+ {
+ foreach (string referenceColumnName in referenceColumnNames)
+ {
+ AdbcUsageSchema usageSchemaExpected = new AdbcUsageSchema()
+ {
+ FkCatalog = databaseName,
+ FkDbSchema = schemaName,
+ FkTable = _fixture._tableName1.ToUpper(),
+ FkColumnName = referenceColumnName
+ };
+ Assert.True(adbcConstraint.ColumnUsage.Where((usageSchema)
=> usageSchema.Equals(usageSchemaExpected)).FirstOrDefault() != null,
$"ColumnUsage should be present for '{referenceColumnName}' column in
'{_fixture._tableName1}' table");
+ }
+ }
+ else
+ {
+ Assert.True(adbcConstraint.ColumnUsage.Count == 0,
"ColumnUsages is only for foreign key");
+ }
+ }
+ }
+
+ public class ConstraintTestsFixutre : IDisposable
+ {
+ public readonly string s_testTablePrefix = "ADBCCONSTRAINTTEST_";
+ readonly SnowflakeTestConfiguration _snowflakeTestConfiguration;
+ public readonly AdbcConnection _connection;
+ public readonly AdbcStatement _statement;
+ public readonly List<string> _tableTypes;
+ public readonly string _tableName1;
+ public readonly string _tableName2;
+ private bool _disposed = false;
+
+ private const string SNOWFLAKE_CONSTRAINTS_DATA_RESOURCE =
"Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.Resources.SnowflakeConstraints.sql";
+
+
+ public ConstraintTestsFixutre()
+ {
+
Skip.IfNot(Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE));
+ _snowflakeTestConfiguration =
SnowflakeTestingUtils.TestConfiguration;
+ _tableName1 = s_testTablePrefix +
Guid.NewGuid().ToString().Replace("-", "");
+ _tableName2 = s_testTablePrefix +
Guid.NewGuid().ToString().Replace("-", "");
+ _tableTypes = new List<string> { "BASE TABLE", "VIEW" };
+ Dictionary<string, string> parameters = new Dictionary<string,
string>();
+ Dictionary<string, string> options = new Dictionary<string,
string>();
+ AdbcDriver snowflakeDriver =
SnowflakeTestingUtils.GetSnowflakeAdbcDriver(_snowflakeTestConfiguration, out
parameters);
+ AdbcDatabase adbcDatabase = snowflakeDriver.Open(parameters);
+ _connection = adbcDatabase.Connect(options);
+ _statement = _connection.CreateStatement();
+ CreateTables();
+ }
+
+ private void CreateTables()
+ {
+ string[] queries =
SnowflakeTestingUtils.GetQueries(_snowflakeTestConfiguration,
SNOWFLAKE_CONSTRAINTS_DATA_RESOURCE);
+
+ Dictionary<string, string> placeholderValues = new
Dictionary<string, string>() {
+ {"{ADBC_CONSTRANT_TABLE_1}", _tableName1 },
+ {"{ADBC_CONSTRANT_TABLE_2}", _tableName2 }
+ };
+
+ for (int i = 0; i < queries.Length; i++)
+ {
+ string query = queries[i];
+ foreach (string key in placeholderValues.Keys)
+ {
+ if (query.Contains(key))
+ query = query.Replace(key, placeholderValues[key]);
+ }
+ UpdateResult updateResult = ExecuteUpdateStatement(query);
+ }
+ }
+
+ private void DropCreatedTables()
+ {
+ string[] queries = new string[] {
+ $"DROP TABLE IF EXISTS
{_snowflakeTestConfiguration.Metadata.Catalog}.{_snowflakeTestConfiguration.Metadata.Schema}.{_tableName1}",
+ $"DROP TABLE IF EXISTS
{_snowflakeTestConfiguration.Metadata.Catalog}.{_snowflakeTestConfiguration.Metadata.Schema}.{_tableName2}"
+ };
+
+ for (int i = 0; i < queries.Length; i++)
+ {
+ string query = queries[i];
+ UpdateResult updateResult = ExecuteUpdateStatement(query);
+ }
+ }
+
+ private UpdateResult ExecuteUpdateStatement(string query)
+ {
+ using AdbcStatement statement = _connection.CreateStatement();
+ statement.SqlQuery = query;
+ UpdateResult updateResult = statement.ExecuteUpdate();
+ return updateResult;
+ }
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (disposing && !_disposed)
+ {
+ DropCreatedTables();
+ _connection?.Dispose();
+ _statement?.Dispose();
+ _disposed = true;
+ }
+ }
+ }
+}
diff --git a/csharp/test/Drivers/Interop/Snowflake/Resources/SnowflakeConstraints.sql b/csharp/test/Drivers/Interop/Snowflake/Resources/SnowflakeConstraints.sql
new file mode 100644
index 000000000..886b04a4e
--- /dev/null
+++ b/csharp/test/Drivers/Interop/Snowflake/Resources/SnowflakeConstraints.sql
@@ -0,0 +1,39 @@
+
+ -- Licensed to the Apache Software Foundation (ASF) under one or more
+ -- contributor license agreements. See the NOTICE file distributed with
+ -- this work for additional information regarding copyright ownership.
+ -- The ASF licenses this file to You under the Apache License, Version 2.0
+ -- (the "License"); you may not use this file except in compliance with
+ -- the License. You may obtain a copy of the License at
+
+ -- http://www.apache.org/licenses/LICENSE-2.0
+
+ -- Unless required by applicable law or agreed to in writing, software
+ -- distributed under the License is distributed on an "AS IS" BASIS,
+ -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ -- See the License for the specific language governing permissions and
+ -- limitations under the License.
+
+ CREATE OR REPLACE TABLE {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_CONSTRANT_TABLE_1}
(
+ id INT,
+ name VARCHAR(20),
+ PRIMARY KEY (id, name)
+);
+
+INSERT INTO {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_CONSTRANT_TABLE_1} (id, name)
+VALUES
+(1, 'snowflake'),
+(2, 'spark');
+
+CREATE OR REPLACE TABLE {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_CONSTRANT_TABLE_2} (
+ id INT UNIQUE,
+ company_name VARCHAR(30) PRIMARY KEY,
+ database_id INT,
+ database_name VARCHAR(20),
+ CONSTRAINT ADBC_FKEY FOREIGN KEY (database_id, database_name) REFERENCES
{ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_CONSTRANT_TABLE_1} (id, name)
+);
+
+INSERT INTO {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_CONSTRANT_TABLE_2} (id,
company_name, database_id, database_name)
+VALUES
+(6, 'Snowflake Inc', 1, 'snowflake'),
+(7, 'The Apache Software Foundation', 2, 'spark');
diff --git a/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestingUtils.cs b/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestingUtils.cs
index 3aa57345e..94cbfe430 100644
--- a/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestingUtils.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestingUtils.cs
@@ -19,6 +19,7 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using System.Reflection;
using System.Text;
using Apache.Arrow.Adbc.Drivers.Interop.Snowflake;
using Xunit;
@@ -44,8 +45,10 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
internal class SnowflakeTestingUtils
{
internal static readonly SnowflakeTestConfiguration TestConfiguration;
+ private static readonly Assembly CurrentAssembly;
internal const string SNOWFLAKE_TEST_CONFIG_VARIABLE =
"SNOWFLAKE_TEST_CONFIG_FILE";
+ private const string SNOWFLAKE_DATA_RESOURCE =
"Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.Resources.SnowflakeData.sql";
static SnowflakeTestingUtils()
{
@@ -57,6 +60,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
{
Console.WriteLine(ex.Message);
}
+
+ CurrentAssembly = Assembly.GetExecutingAssembly();
}
/// <summary>
@@ -134,11 +139,35 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
/// Parses the queries from resources/SnowflakeData.sql
/// </summary>
/// <param name="testConfiguration"><see
cref="SnowflakeTestConfiguration"/></param>
- internal static string[] GetQueries(SnowflakeTestConfiguration
testConfiguration)
+ internal static string[] GetQueries(SnowflakeTestConfiguration
testConfiguration, string resourceName = SNOWFLAKE_DATA_RESOURCE)
{
StringBuilder content = new StringBuilder();
- string[] sql = File.ReadAllLines("resources/SnowflakeData.sql");
+ string[] sql = null;
+
+ try
+ {
+ using (Stream stream =
CurrentAssembly.GetManifestResourceStream(resourceName))
+ {
+ if (stream != null)
+ {
+ using (StreamReader sr = new StreamReader(stream))
+ {
+ sql = sr.ReadToEnd().Split(new[] {
Environment.NewLine }, StringSplitOptions.None);
+ }
+ }
+ else
+ {
+ throw new FileNotFoundException("Embedded resource not
found", resourceName);
+ }
+ }
+ }
+ catch (Exception ex)
+ {
+                Console.WriteLine($"An error occurred while reading the resource: {resourceName}");
+ Console.WriteLine(ex.Message);
+ throw;
+ }
Dictionary<string, string> placeholderValues = new
Dictionary<string, string>() {
{"{ADBC_CATALOG}", testConfiguration.Metadata.Catalog },
diff --git a/go/adbc/driver/internal/shared_utils.go b/go/adbc/driver/internal/shared_utils.go
index de48bde0c..b5dcf55ca 100644
--- a/go/adbc/driver/internal/shared_utils.go
+++ b/go/adbc/driver/internal/shared_utils.go
@@ -31,23 +31,46 @@ import (
"github.com/apache/arrow/go/v16/arrow/memory"
)
+const (
+ Unique = "UNIQUE"
+ PrimaryKey = "PRIMARY KEY"
+ ForeignKey = "FOREIGN KEY"
+)
+
type CatalogAndSchema struct {
Catalog, Schema string
}
+type CatalogSchemaTable struct {
+ Catalog, Schema, Table string
+}
+
+type CatalogSchemaTableColumn struct {
+ Catalog, Schema, Table, Column string
+}
+
type TableInfo struct {
Name, TableType string
Schema *arrow.Schema
}
type Metadata struct {
- Created
time.Time
- ColName, DataType
string
- Dbname, Kind, Schema, TblName, TblType, IdentGen, IdentIncrement,
Comment sql.NullString
- OrdinalPos
int
- NumericPrec, NumericPrecRadix, NumericScale, DatetimePrec
sql.NullInt16
- IsNullable, IsIdent
bool
- CharMaxLength, CharOctetLength
sql.NullInt32
+ Created
time.Time
+ ColName, DataType, Dbname, Kind, Schema, TblName, TblType, IdentGen,
IdentIncrement, Comment, ConstraintName, ConstraintType sql.NullString
+ OrdinalPos
int
+ NumericPrec, NumericPrecRadix, NumericScale, DatetimePrec
sql.NullInt16
+ IsNullable, IsIdent
bool
+ CharMaxLength, CharOctetLength
sql.NullInt32
+}
+
+type UsageSchema struct {
+ ForeignKeyCatalog, ForeignKeyDbSchema, ForeignKeyTable,
ForeignKeyColName string
+}
+
+type ConstraintSchema struct {
+ ConstraintName, ConstraintType string
+ ConstraintColumnNames []string
+ ConstraintColumnUsages []UsageSchema
}
type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, schema *string, metadataRecords []Metadata)
(map[string][]string, error)
@@ -99,6 +122,7 @@ type GetObjects struct {
builder *array.RecordBuilder
schemaLookup map[string][]string
tableLookup map[CatalogAndSchema][]TableInfo
+ ConstraintLookup map[CatalogSchemaTable][]ConstraintSchema
MetadataRecords []Metadata
catalogPattern *regexp.Regexp
columnNamePattern *regexp.Regexp
@@ -133,6 +157,17 @@ type GetObjects struct {
xdbcIsAutoincrementBuilder *array.BooleanBuilder
xdbcIsGeneratedcolumnBuilder *array.BooleanBuilder
tableConstraintsBuilder *array.ListBuilder
+ tableConstraintsItems *array.StructBuilder
+ constraintNameBuilder *array.StringBuilder
+ constraintTypeBuilder *array.StringBuilder
+ constraintColumnNameBuilder *array.ListBuilder
+ constraintColumnUsageBuilder *array.ListBuilder
+ constraintColumnNameItems *array.StringBuilder
+ constraintColumnUsageItems *array.StructBuilder
+ columnUsageCatalogBuilder *array.StringBuilder
+ columnUsageSchemaBuilder *array.StringBuilder
+ columnUsageTableBuilder *array.StringBuilder
+ columnUsageColumnBuilder *array.StringBuilder
}
func (g *GetObjects) Init(mem memory.Allocator, getObj GetObjDBSchemasFn,
getTbls GetObjTablesFn) error {
@@ -196,6 +231,17 @@ func (g *GetObjects) Init(mem memory.Allocator, getObj GetObjDBSchemasFn, getTbl
g.xdbcIsAutoincrementBuilder =
g.tableColumnsItems.FieldBuilder(17).(*array.BooleanBuilder)
g.xdbcIsGeneratedcolumnBuilder =
g.tableColumnsItems.FieldBuilder(18).(*array.BooleanBuilder)
g.tableConstraintsBuilder =
g.dbSchemaTablesItems.FieldBuilder(3).(*array.ListBuilder)
+ g.tableConstraintsItems =
g.tableConstraintsBuilder.ValueBuilder().(*array.StructBuilder)
+ g.constraintNameBuilder =
g.tableConstraintsItems.FieldBuilder(0).(*array.StringBuilder)
+ g.constraintTypeBuilder =
g.tableConstraintsItems.FieldBuilder(1).(*array.StringBuilder)
+ g.constraintColumnNameBuilder =
g.tableConstraintsItems.FieldBuilder(2).(*array.ListBuilder)
+ g.constraintColumnNameItems =
g.constraintColumnNameBuilder.ValueBuilder().(*array.StringBuilder)
+ g.constraintColumnUsageBuilder =
g.tableConstraintsItems.FieldBuilder(3).(*array.ListBuilder)
+ g.constraintColumnUsageItems =
g.constraintColumnUsageBuilder.ValueBuilder().(*array.StructBuilder)
+ g.columnUsageCatalogBuilder =
g.constraintColumnUsageItems.FieldBuilder(0).(*array.StringBuilder)
+ g.columnUsageSchemaBuilder =
g.constraintColumnUsageItems.FieldBuilder(1).(*array.StringBuilder)
+ g.columnUsageTableBuilder =
g.constraintColumnUsageItems.FieldBuilder(2).(*array.StringBuilder)
+ g.columnUsageColumnBuilder =
g.constraintColumnUsageItems.FieldBuilder(3).(*array.StringBuilder)
return nil
}
@@ -246,27 +292,80 @@ func (g *GetObjects) appendDbSchema(catalogName, dbSchemaName string) {
}
g.dbSchemaTablesBuilder.Append(true)
- for _, tableInfo := range g.tableLookup[CatalogAndSchema{
- Catalog: catalogName,
- Schema: dbSchemaName,
- }] {
- g.appendTableInfo(tableInfo)
+ catalogAndSchema := CatalogAndSchema{Catalog: catalogName, Schema:
dbSchemaName}
+ for _, tableInfo := range g.tableLookup[catalogAndSchema] {
+ g.appendTableInfo(tableInfo, catalogAndSchema)
}
}
-func (g *GetObjects) appendTableInfo(tableInfo TableInfo) {
+func (g *GetObjects) appendTableInfo(tableInfo TableInfo, catalogAndSchema
CatalogAndSchema) {
g.tableNameBuilder.Append(tableInfo.Name)
g.tableTypeBuilder.Append(tableInfo.TableType)
g.dbSchemaTablesItems.Append(true)
+ g.appendTableConstraints(tableInfo, catalogAndSchema)
+ g.appendColumnsInfo(tableInfo)
+}
+
+func (g *GetObjects) appendTableConstraints(tableInfo TableInfo,
catalogAndSchema CatalogAndSchema) {
if g.Depth == adbc.ObjectDepthTables {
- g.tableColumnsBuilder.AppendNull()
g.tableConstraintsBuilder.AppendNull()
return
}
- g.tableColumnsBuilder.Append(true)
- // TODO: unimplemented for now
+
g.tableConstraintsBuilder.Append(true)
+ if len(g.ConstraintLookup) == 0 {
+ // Empty list
+ return
+ }
+
+ catalogSchemaTable := CatalogSchemaTable{Catalog:
catalogAndSchema.Catalog, Schema: catalogAndSchema.Schema, Table:
tableInfo.Name}
+ constraintSchemaData, exists := g.ConstraintLookup[catalogSchemaTable]
+
+ if exists {
+ for _, data := range constraintSchemaData {
+ g.constraintNameBuilder.Append(data.ConstraintName)
+ g.constraintTypeBuilder.Append(data.ConstraintType)
+ g.appendConstraintColumns(data)
+ g.appendConstraintColumnUsages(data)
+ g.tableConstraintsItems.Append(true)
+ }
+ }
+}
+
+func (g *GetObjects) appendConstraintColumns(constraintSchema
ConstraintSchema) {
+ if len(constraintSchema.ConstraintColumnNames) == 0 {
+ g.constraintColumnNameBuilder.AppendNull()
+ } else {
+ g.constraintColumnNameBuilder.Append(true)
+ for _, columnName := range
constraintSchema.ConstraintColumnNames {
+ g.constraintColumnNameItems.Append(columnName)
+ }
+ }
+}
+
+func (g *GetObjects) appendConstraintColumnUsages(constraintSchema
ConstraintSchema) {
+ if len(constraintSchema.ConstraintColumnUsages) == 0 {
+ g.constraintColumnUsageBuilder.AppendNull()
+ } else {
+ g.constraintColumnUsageBuilder.Append(true)
+ for _, columnUsages := range
constraintSchema.ConstraintColumnUsages {
+
g.columnUsageCatalogBuilder.Append(columnUsages.ForeignKeyCatalog)
+
g.columnUsageSchemaBuilder.Append(columnUsages.ForeignKeyDbSchema)
+
g.columnUsageTableBuilder.Append(columnUsages.ForeignKeyTable)
+
g.columnUsageColumnBuilder.Append(columnUsages.ForeignKeyColName)
+ g.constraintColumnUsageItems.Append(true)
+ }
+ }
+}
+
+func (g *GetObjects) appendColumnsInfo(tableInfo TableInfo) {
+ if g.Depth == adbc.ObjectDepthTables {
+ g.tableColumnsBuilder.AppendNull()
+ return
+ }
+
+ g.tableColumnsBuilder.Append(true)
if tableInfo.Schema == nil {
return
diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index 4b023b050..9bc3ef54f 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -18,12 +18,14 @@
package snowflake
import (
+ "cmp"
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"regexp"
+ "slices"
"strconv"
"strings"
"time"
@@ -63,6 +65,24 @@ type connectionImpl struct {
useHighPrecision bool
}
+// QualifiedConstraint uniquely identifies a constraint by dbName, schema, and constraintName,
+// since Snowflake allows the same constraint name to exist in different schemas.
+// The table name is stored additionally to map to the internal.CatalogSchemaTable struct.
+type QualifiedConstraint struct {
+ catalogSchemaTable internal.CatalogSchemaTable
+ constraintName string
+}
+
+type TableConstraint struct {
+ dbName, schema, tblName, colName, constraintName, constraintType string
+ fkDbName, fkSchema, fkTblName, fkColName, fkConstraintName
sql.NullString
+ skipUpdateRule, skipDeleteRule, skipDeferrability string
+ keySequence int
+ skipComment
sql.NullString
+ skipCreatedOn
time.Time
+ skipRely bool
+}
+
// ListTableTypes implements driverbase.TableTypeLister.
func (*connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) {
return []string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil
@@ -113,7 +133,6 @@ func (c *connectionImpl) SetAutocommit(enabled bool) error {
}
_, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET
AUTOCOMMIT = false", nil)
return err
-
}
// Metadata methods
@@ -228,6 +247,13 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog,
DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType:
tableType}
g.MetadataRecords = metadataRecords
+
+ constraintLookup, err := c.populateConstraintSchema(ctx, depth,
metadataRecords)
+ g.ConstraintLookup = constraintLookup
+ if err != nil {
+ return nil, err
+ }
+
if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas,
c.getObjectsTables); err != nil {
return nil, err
}
@@ -254,15 +280,16 @@ func (c *connectionImpl) getObjectsDbSchemas(ctx context.Context, depth adbc.Obj
}
result = make(map[string][]string)
- uniqueCatalogSchema := make(map[string]map[string]bool)
+ uniqueCatalogSchema := make(map[internal.CatalogAndSchema]bool)
for _, data := range metadataRecords {
if !data.Dbname.Valid || !data.Schema.Valid {
continue
}
- if _, exists := uniqueCatalogSchema[data.Dbname.String];
!exists {
- uniqueCatalogSchema[data.Dbname.String] =
make(map[string]bool)
+ catalogSchemaInfo := internal.CatalogAndSchema{
+ Catalog: data.Dbname.String,
+ Schema: data.Schema.String,
}
cat, exists := result[data.Dbname.String]
@@ -270,9 +297,9 @@ func (c *connectionImpl) getObjectsDbSchemas(ctx context.Context, depth adbc.Obj
cat = make([]string, 0, 1)
}
- if _, exists :=
uniqueCatalogSchema[data.Dbname.String][data.Schema.String]; !exists {
+ if _, exists := uniqueCatalogSchema[catalogSchemaInfo]; !exists
{
+ uniqueCatalogSchema[catalogSchemaInfo] = true
result[data.Dbname.String] = append(cat,
data.Schema.String)
-
uniqueCatalogSchema[data.Dbname.String][data.Schema.String] = true
}
}
@@ -442,22 +469,20 @@ func (c *connectionImpl) getObjectsTables(ctx context.Context, depth adbc.Object
result = make(internal.SchemaToTableInfo)
includeSchema := depth == adbc.ObjectDepthAll || depth ==
adbc.ObjectDepthColumns
- uniqueCatalogSchemaTable := make(map[string]map[string]map[string]bool)
+ uniqueCatalogSchemaTable := make(map[internal.CatalogSchemaTable]bool)
for _, data := range metadataRecords {
if !data.Dbname.Valid || !data.Schema.Valid ||
!data.TblName.Valid || !data.TblType.Valid {
continue
}
- if _, exists := uniqueCatalogSchemaTable[data.Dbname.String];
!exists {
- uniqueCatalogSchemaTable[data.Dbname.String] =
make(map[string]map[string]bool)
- }
-
- if _, exists :=
uniqueCatalogSchemaTable[data.Dbname.String][data.Schema.String]; !exists {
-
uniqueCatalogSchemaTable[data.Dbname.String][data.Schema.String] =
make(map[string]bool)
+ catalogSchemaTableInfo := internal.CatalogSchemaTable{
+ Catalog: data.Dbname.String,
+ Schema: data.Schema.String,
+ Table: data.TblName.String,
}
- if _, exists :=
uniqueCatalogSchemaTable[data.Dbname.String][data.Schema.String][data.TblName.String];
!exists {
-
uniqueCatalogSchemaTable[data.Dbname.String][data.Schema.String][data.TblName.String]
= true
+ if _, exists :=
uniqueCatalogSchemaTable[catalogSchemaTableInfo]; !exists {
+ uniqueCatalogSchemaTable[catalogSchemaTableInfo] = true
key := internal.CatalogAndSchema{
Catalog: data.Dbname.String, Schema:
data.Schema.String}
@@ -474,8 +499,10 @@ func (c *connectionImpl) getObjectsTables(ctx context.Context, depth adbc.Object
fieldList = make([]arrow.Field, 0)
)
+ uniqueColumn := make(map[internal.CatalogSchemaTableColumn]bool)
+
for _, data := range metadataRecords {
- if !data.Dbname.Valid || !data.Schema.Valid ||
!data.TblName.Valid {
+ if !data.Dbname.Valid || !data.Schema.Valid ||
!data.TblName.Valid || !data.ColName.Valid {
continue
}
@@ -496,7 +523,16 @@ func (c *connectionImpl) getObjectsTables(ctx context.Context, depth adbc.Object
}
prevKey = key
- fieldList = append(fieldList, toField(data.ColName,
data.IsNullable, data.DataType, data.NumericPrec, data.NumericPrecRadix,
data.NumericScale, data.IsIdent, c.useHighPrecision, data.IdentGen,
data.IdentIncrement, data.CharMaxLength, data.CharOctetLength,
data.DatetimePrec, data.Comment, data.OrdinalPos))
+ columnInfo := internal.CatalogSchemaTableColumn{
+ Catalog: data.Dbname.String,
+ Schema: data.Schema.String,
+ Table: data.TblName.String,
+ Column: data.ColName.String,
+ }
+ if _, exists := uniqueColumn[columnInfo]; !exists {
+ uniqueColumn[columnInfo] = true
+ fieldList = append(fieldList,
toField(data.ColName.String, data.IsNullable, data.DataType.String,
data.NumericPrec, data.NumericPrecRadix, data.NumericScale, data.IsIdent,
c.useHighPrecision, data.IdentGen, data.IdentIncrement, data.CharMaxLength,
data.CharOctetLength, data.DatetimePrec, data.Comment, data.OrdinalPos))
+ }
}
if len(fieldList) > 0 && curTableInfo != nil {
@@ -508,6 +544,7 @@ func (c *connectionImpl) getObjectsTables(ctx context.Context, depth adbc.Object
func (c *connectionImpl) populateMetadata(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string,
columnName *string, tableType []string) ([]internal.Metadata, error) {
var metadataRecords []internal.Metadata
+
catalogMetadataRecords, err := c.getCatalogsMetadata(ctx)
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
@@ -525,10 +562,19 @@ func (c *connectionImpl) populateMetadata(ctx context.Context, depth adbc.Object
metadataRecords = catalogMetadataRecords
} else if depth == adbc.ObjectDepthDBSchemas {
metadataRecords, err = c.getDbSchemasMetadata(ctx,
matchingCatalogNames, catalog, dbSchema)
+
} else if depth == adbc.ObjectDepthTables {
metadataRecords, err = c.getTablesMetadata(ctx,
matchingCatalogNames, catalog, dbSchema, tableName, tableType)
} else {
- metadataRecords, err = c.getColumnsMetadata(ctx,
matchingCatalogNames, catalog, dbSchema, tableName, columnName, tableType)
+ tableMetadataRecords, tablesErr := c.getTablesMetadata(ctx,
matchingCatalogNames, catalog, dbSchema, tableName, tableType)
+ if tablesErr != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ columnsMetadataRecords, columnsErr := c.getColumnsMetadata(ctx,
matchingCatalogNames, catalog, dbSchema, tableName, columnName, tableType)
+ if columnsErr != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ metadataRecords = append(tableMetadataRecords,
columnsMetadataRecords...)
}
if err != nil {
@@ -538,6 +584,187 @@ func (c *connectionImpl) populateMetadata(ctx context.Context, depth adbc.Object
return metadataRecords, nil
}
+func (c *connectionImpl) populateConstraintSchema(ctx context.Context, depth
adbc.ObjectDepth, metadataRecords []internal.Metadata)
(map[internal.CatalogSchemaTable][]internal.ConstraintSchema, error) {
+ constraintLookup :=
make(map[internal.CatalogSchemaTable][]internal.ConstraintSchema)
+ tableConstraintsData, err := c.getConstraintsData(ctx, depth,
metadataRecords)
+ if err != nil {
+ return nil, err
+ }
+
+ // we want to avoid creating duplicate entries for a constraint
+ qualifiedConstraintLookup :=
make(map[QualifiedConstraint]internal.ConstraintSchema)
+ for _, data := range tableConstraintsData {
+ var qualifiedConstraint QualifiedConstraint
+ // columnUsages is only relevant for a foreign key
+ if data.fkConstraintName.Valid {
+ qualifiedConstraint =
getQualifiedConstraint(data.fkDbName.String, data.fkSchema.String,
data.fkTblName.String, data.fkConstraintName.String)
+ if _, exists :=
qualifiedConstraintLookup[qualifiedConstraint]; !exists {
+ qualifiedConstraintLookup[qualifiedConstraint]
= getConstraintSchemaFromTableConstraint(data)
+ } else {
+ constraintInfo :=
qualifiedConstraintLookup[qualifiedConstraint]
+ // appending additional column names and column
usages for foreign key constraints
+ constraintInfo.ConstraintColumnNames =
append(constraintInfo.ConstraintColumnNames, data.fkColName.String)
+ constraintInfo.ConstraintColumnUsages =
append(constraintInfo.ConstraintColumnUsages,
getUsageSchemaFromTableConstraint(data))
+ qualifiedConstraintLookup[qualifiedConstraint]
= constraintInfo
+ }
+ } else {
+ qualifiedConstraint =
getQualifiedConstraint(data.dbName, data.schema, data.tblName,
data.constraintName)
+ if _, exists :=
qualifiedConstraintLookup[qualifiedConstraint]; !exists {
+ qualifiedConstraintLookup[qualifiedConstraint]
= getConstraintSchemaFromTableConstraint(data)
+ } else {
+ constraintInfo :=
qualifiedConstraintLookup[qualifiedConstraint]
+ // appending additional column names for
primary and unique key constraints
+ constraintInfo.ConstraintColumnNames =
append(constraintInfo.ConstraintColumnNames, data.colName)
+ qualifiedConstraintLookup[qualifiedConstraint]
= constraintInfo
+ }
+ }
+ }
+
+ // adding all the unique constraints to a constraint lookup using the
catalogSchemaTable as a key
+ for qualifiedConstraint, constraintSchema := range
qualifiedConstraintLookup {
+ catalogSchemaTable := qualifiedConstraint.catalogSchemaTable
+ constraintLookup[catalogSchemaTable] =
append(constraintLookup[catalogSchemaTable], constraintSchema)
+ }
+
+ return constraintLookup, nil
+}
+
+func (c *connectionImpl) getConstraintsData(ctx context.Context, depth
adbc.ObjectDepth, metadataRecords []internal.Metadata) ([]TableConstraint,
error) {
+ if depth == adbc.ObjectDepthCatalogs || depth ==
adbc.ObjectDepthDBSchemas {
+ return nil, nil
+ }
+ availableConstraintTypes := getAvailableConstraintTypes(metadataRecords)
+ availableFullyQualifiedConstraints :=
getAvailableConstraints(metadataRecords)
+
+ var uniqueConstraintsData []TableConstraint
+ var primaryKeyConstraintsData []TableConstraint
+ var foreignKeyConstraintsData []TableConstraint
+ var err error
+
+ if availableConstraintTypes != nil {
+ if _, exists := availableConstraintTypes[internal.Unique];
exists {
+ uniqueConstraintsData, err =
c.getUniqueConstraints(ctx, availableFullyQualifiedConstraints)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ }
+
+ if _, exists := availableConstraintTypes[internal.PrimaryKey];
exists {
+ primaryKeyConstraintsData, err =
c.getPrimaryKeyConstraints(ctx, availableFullyQualifiedConstraints)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ }
+
+ if _, exists := availableConstraintTypes[internal.ForeignKey];
exists {
+ foreignKeyConstraintsData, err =
c.getForeignKeyConstraints(ctx, availableFullyQualifiedConstraints)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ }
+ }
+
+ tableConstraintsData := append(append(uniqueConstraintsData,
primaryKeyConstraintsData...), foreignKeyConstraintsData...)
+
+ slices.SortFunc(tableConstraintsData, func(i, j TableConstraint) int {
+ if n := cmp.Compare(i.constraintName, j.constraintName); n != 0
{
+ return n
+ }
+		// If constraint names are equal, order by keySequence
+ return cmp.Compare(i.keySequence, j.keySequence)
+ })
+
+ return tableConstraintsData, nil
+}
+
+func (c *connectionImpl) getUniqueConstraints(ctx context.Context,
fullyQualifiedConstraints map[QualifiedConstraint]bool) ([]TableConstraint,
error) {
+ uniqueConstraintsData := make([]TableConstraint, 0)
+
+ rows, err := c.sqldb.QueryContext(ctx, prepareUniqueConstraintSQL(),
nil)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var uniqueConstraint TableConstraint
+ if err := rows.Scan(&uniqueConstraint.skipCreatedOn,
&uniqueConstraint.dbName, &uniqueConstraint.schema,
+ &uniqueConstraint.tblName, &uniqueConstraint.colName,
&uniqueConstraint.keySequence,
+ &uniqueConstraint.constraintName,
&uniqueConstraint.skipRely, &uniqueConstraint.skipComment); err != nil {
+ return nil, errToAdbcErr(adbc.StatusInvalidData, err)
+ }
+
+ currentQualifiedConstraint :=
getQualifiedConstraint(uniqueConstraint.dbName, uniqueConstraint.schema,
uniqueConstraint.tblName, uniqueConstraint.constraintName)
+
+ // skip constraint if it doesn't exist in
fullyQualifiedConstraints
+ if _, exists :=
fullyQualifiedConstraints[currentQualifiedConstraint]; exists {
+ uniqueConstraint.constraintType = internal.Unique
+ uniqueConstraintsData = append(uniqueConstraintsData,
uniqueConstraint)
+ }
+
+ }
+ return uniqueConstraintsData, nil
+}
+
+func (c *connectionImpl) getPrimaryKeyConstraints(ctx context.Context,
fullyQualifiedConstraints map[QualifiedConstraint]bool) ([]TableConstraint,
error) {
+ primaryKeyConstraintsData := make([]TableConstraint, 0)
+
+ rows, err := c.sqldb.QueryContext(ctx,
preparePrimaryKeyConstraintSQL(), nil)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var primaryKeyConstraint TableConstraint
+ if err := rows.Scan(&primaryKeyConstraint.skipCreatedOn,
&primaryKeyConstraint.dbName, &primaryKeyConstraint.schema,
+ &primaryKeyConstraint.tblName,
&primaryKeyConstraint.colName, &primaryKeyConstraint.keySequence,
+ &primaryKeyConstraint.constraintName,
&primaryKeyConstraint.skipRely, &primaryKeyConstraint.skipComment); err != nil {
+ return nil, errToAdbcErr(adbc.StatusInvalidData, err)
+ }
+
+ currentQualifiedConstraint :=
getQualifiedConstraint(primaryKeyConstraint.dbName,
primaryKeyConstraint.schema, primaryKeyConstraint.tblName,
primaryKeyConstraint.constraintName)
+
+ // skip constraint if it doesn't exist in
fullyQualifiedConstraints
+ if _, exists :=
fullyQualifiedConstraints[currentQualifiedConstraint]; exists {
+ primaryKeyConstraint.constraintType =
internal.PrimaryKey
+ primaryKeyConstraintsData =
append(primaryKeyConstraintsData, primaryKeyConstraint)
+ }
+
+ }
+ return primaryKeyConstraintsData, nil
+}
+
+func (c *connectionImpl) getForeignKeyConstraints(ctx context.Context,
qualifiedConstraints map[QualifiedConstraint]bool) ([]TableConstraint, error) {
+ foreignKeyConstraintsData := make([]TableConstraint, 0)
+
+ rows, err := c.sqldb.QueryContext(ctx,
prepareForeignKeyConstraintSQL(), nil)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var fkConstraint TableConstraint
+ if err := rows.Scan(&fkConstraint.skipCreatedOn,
&fkConstraint.dbName, &fkConstraint.schema,
+ &fkConstraint.tblName, &fkConstraint.colName,
&fkConstraint.fkDbName, &fkConstraint.fkSchema,
+ &fkConstraint.fkTblName, &fkConstraint.fkColName,
&fkConstraint.keySequence,
+ &fkConstraint.skipUpdateRule,
&fkConstraint.skipDeleteRule, &fkConstraint.fkConstraintName,
+ &fkConstraint.constraintName,
&fkConstraint.skipDeferrability, &fkConstraint.skipRely,
&fkConstraint.skipComment); err != nil {
+ return nil, errToAdbcErr(adbc.StatusInvalidData, err)
+ }
+
+ currentQualifiedConstaint :=
getQualifiedConstraint(fkConstraint.fkDbName.String,
fkConstraint.fkSchema.String, fkConstraint.fkTblName.String,
fkConstraint.fkConstraintName.String)
+
+ // skip constraint if it doesn't exist in qualifiedConstraints
+ if _, exists :=
qualifiedConstraints[currentQualifiedConstaint]; exists {
+ fkConstraint.constraintType = internal.ForeignKey
+ foreignKeyConstraintsData =
append(foreignKeyConstraintsData, fkConstraint)
+ }
+ }
+ return foreignKeyConstraintsData, nil
+}
+
func (c *connectionImpl) getCatalogsMetadata(ctx context.Context)
([]internal.Metadata, error) {
metadataRecords := make([]internal.Metadata, 0)
@@ -545,6 +772,7 @@ func (c *connectionImpl) getCatalogsMetadata(ctx context.Context) ([]internal.Me
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
}
+ defer rows.Close()
for rows.Next() {
var data internal.Metadata
@@ -597,7 +825,7 @@ func (c *connectionImpl) getTablesMetadata(ctx context.Context, matchingCatalogN
for rows.Next() {
var data internal.Metadata
- if err = rows.Scan(&data.Dbname, &data.Schema, &data.TblName,
&data.TblType); err != nil {
+ if err = rows.Scan(&data.Dbname, &data.Schema, &data.TblName,
&data.TblType, &data.ConstraintName, &data.ConstraintType); err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
}
metadataRecords = append(metadataRecords, data)
@@ -618,7 +846,7 @@ func (c *connectionImpl) getColumnsMetadata(ctx context.Context, matchingCatalog
for rows.Next() {
// order here matches the order of the columns requested in the
query
- err = rows.Scan(&data.TblType, &data.Dbname, &data.Schema,
&data.TblName, &data.ColName,
+ err = rows.Scan(&data.Dbname, &data.Schema, &data.TblName,
&data.ColName,
&data.OrdinalPos, &data.IsNullable, &data.DataType,
&data.NumericPrec,
&data.NumericPrecRadix, &data.NumericScale,
&data.IsIdent, &data.IdentGen,
&data.IdentIncrement, &data.CharMaxLength,
&data.CharOctetLength, &data.DatetimePrec, &data.Comment)
@@ -630,6 +858,77 @@ func (c *connectionImpl) getColumnsMetadata(ctx context.Context, matchingCatalog
return metadataRecords, nil
}
+func getAvailableConstraintTypes(metadataRecords []internal.Metadata)
map[string]bool {
+ availableConstraintType := make(map[string]bool)
+ for _, data := range metadataRecords {
+ if data.ConstraintType.Valid {
+ switch data.ConstraintType.String {
+ case internal.Unique:
+ availableConstraintType[internal.Unique] = true
+ case internal.PrimaryKey:
+ availableConstraintType[internal.PrimaryKey] = true
+ case internal.ForeignKey:
+ availableConstraintType[internal.ForeignKey] = true
+ default:
+ }
+ }
+ }
+ return availableConstraintType
+}
+
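+// getAvailableConstraints collects the qualified (catalog, schema, table, name)
+// identifiers of every constraint found in the metadata records.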
+func getAvailableConstraints(metadataRecords []internal.Metadata) map[QualifiedConstraint]bool {
+ qualifiedConstraints := make(map[QualifiedConstraint]bool)
+ for _, data := range metadataRecords {
+ if data.ConstraintName.Valid {
+ qualifiedConstraint := getQualifiedConstraint(data.Dbname.String, data.Schema.String, data.TblName.String, data.ConstraintName.String)
+ qualifiedConstraints[qualifiedConstraint] = true
+ }
+ }
+ return qualifiedConstraints
+}
+
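+// getQualifiedConstraint builds the lookup key used to match SHOW <KEYS> rows
+// against constraints discovered through INFORMATION_SCHEMA.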
+func getQualifiedConstraint(dbName string, schema string, tableName string, constraintName string) QualifiedConstraint {
+ return QualifiedConstraint{
+ catalogSchemaTable: internal.CatalogSchemaTable{
+ Catalog: dbName,
+ Schema: schema,
+ Table: tableName,
+ },
+ constraintName: constraintName,
+ }
+}
+
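+// getConstraintSchemaFromTableConstraint converts a TableConstraint row into an
+// internal.ConstraintSchema; foreign keys also populate ConstraintColumnUsages
+// with the referenced column.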
+func getConstraintSchemaFromTableConstraint(tableConstraint TableConstraint) internal.ConstraintSchema {
+ var constraintSchema internal.ConstraintSchema
+ constraintSchema.ConstraintType = tableConstraint.constraintType
+
+ if tableConstraint.fkConstraintName.Valid {
+ usageSchema := getUsageSchemaFromTableConstraint(tableConstraint)
+ constraintSchema.ConstraintName = tableConstraint.fkConstraintName.String
+ constraintSchema.ConstraintColumnNames = []string{tableConstraint.fkColName.String}
+ constraintSchema.ConstraintColumnUsages = []internal.UsageSchema{usageSchema}
+ } else {
+ constraintSchema.ConstraintName = tableConstraint.constraintName
+ constraintSchema.ConstraintColumnNames = []string{tableConstraint.colName}
+ }
+ return constraintSchema
+}
+
+func getUsageSchemaFromTableConstraint(tableConstraint TableConstraint) internal.UsageSchema {
+ var usageSchema internal.UsageSchema
+ // usageSchema is only applicable for foreign key constraint
+ if tableConstraint.fkConstraintName.Valid {
+ // reference column for a foreign key constraint
+ usageSchema = internal.UsageSchema{
+ ForeignKeyCatalog: tableConstraint.dbName,
+ ForeignKeyDbSchema: tableConstraint.schema,
+ ForeignKeyTable: tableConstraint.tblName,
+ ForeignKeyColName: tableConstraint.colName,
+ }
+ }
+ return usageSchema
+}
+
func getMatchingCatalogNames(metadataRecords []internal.Metadata, catalog *string) ([]string, error) {
matchingCatalogNames := make([]string, 0)
var catalogPattern *regexp.Regexp
@@ -682,10 +981,20 @@ func prepareTablesSQL(matchingCatalogNames []string, catalog *string, dbSchema *
if query != "" {
query += " UNION ALL "
}
- query += `SELECT * FROM "` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES`
- }
-
- query = `SELECT table_catalog, table_schema, table_name, table_type FROM (` + query + `)`
+ query += `SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ )`
+ }
+
+ query = `SELECT table_catalog, table_schema, table_name, table_type, constraint_name, constraint_type FROM (` + query + `)`
conditions, queryArgs := prepareFilterConditions(adbc.ObjectDepthTables, catalog, dbSchema, tableName, nil, tableType)
if conditions != "" {
query += " WHERE " + conditions
@@ -699,19 +1008,10 @@ func prepareColumnsSQL(matchingCatalogNames []string, catalog *string, dbSchema
if prefixQuery != "" {
prefixQuery += " UNION ALL "
}
- prefixQuery += `SELECT T.table_type,
- C.*
- FROM
- "` + strings.ReplaceAll(catalogName, "\"",
"\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
- JOIN
- "` + strings.ReplaceAll(catalogName, "\"",
"\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
- ON
- T.table_catalog = C.table_catalog
- AND T.table_schema = C.table_schema
- AND t.table_name = C.table_name`
- }
-
- prefixQuery = `SELECT table_type, table_catalog, table_schema, table_name, column_name,
+ prefixQuery += `SELECT * FROM "` + strings.ReplaceAll(catalogName, "\"", "\"\"") + `".INFORMATION_SCHEMA.COLUMNS`
+ }
+
+ prefixQuery = `SELECT table_catalog, table_schema, table_name, column_name,
ordinal_position,
is_nullable::boolean, data_type, numeric_precision,
numeric_precision_radix,
numeric_scale, is_identity::boolean,
identity_generation,
identity_increment,
@@ -757,7 +1057,7 @@ func prepareFilterConditions(depth adbc.ObjectDepth, catalog *string, dbSchema *
}
var tblConditions []string
- if len(tableType) > 0 {
+ if len(tableType) > 0 && depth == adbc.ObjectDepthTables {
tblConditions = append(conditions, ` TABLE_TYPE IN ('`+strings.Join(tableType, `','`)+`')`)
} else {
tblConditions = conditions
@@ -767,6 +1067,18 @@ func prepareFilterConditions(depth adbc.ObjectDepth, catalog *string, dbSchema *
return cond, queryArgs
}
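+// The SHOW commands below return Snowflake's key-constraint listings; callers
+// filter the resulting rows against the constraints discovered via INFORMATION_SCHEMA.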
+func prepareUniqueConstraintSQL() string {
+ return "SHOW UNIQUE KEYS"
+}
+
+func preparePrimaryKeyConstraintSQL() string {
+ return "SHOW PRIMARY KEYS"
+}
+
+func prepareForeignKeyConstraintSQL() string {
+ return "SHOW EXPORTED KEYS"
+}
+
func descToField(name, typ, isnull, primary string, comment sql.NullString) (field arrow.Field, err error) {
field.Name = strings.ToLower(name)
if isnull == "Y" {
diff --git a/go/adbc/driver/snowflake/connection_test.go b/go/adbc/driver/snowflake/connection_test.go
index 983a206d8..ae61749de 100644
--- a/go/adbc/driver/snowflake/connection_test.go
+++ b/go/adbc/driver/snowflake/connection_test.go
@@ -158,15 +158,38 @@ func TestPrepareTablesSQLWithNoFilter(t *testing.T) {
tableNamePattern := ""
tableType := make([]string, 0)
- expected := `SELECT table_catalog, table_schema, table_name, table_type
- FROM
- (
- SELECT * FROM "DEMO_DB".INFORMATION_SCHEMA.TABLES
- UNION ALL
- SELECT * FROM "DEMOADB".INFORMATION_SCHEMA.TABLES
- UNION ALL
- SELECT * FROM "DEMO'DB".INFORMATION_SCHEMA.TABLES
- )`
+ expected := `SELECT table_catalog, table_schema, table_name, table_type, constraint_name, constraint_type FROM (SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type,
+ TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMO_DB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMO_DB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ ) UNION ALL SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMOADB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMOADB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ ) UNION ALL SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMO'DB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMO'DB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ ))`
actual, queryArgs := prepareTablesSQL(catalogNames[:], &catalogPattern, &schemaPattern, &tableNamePattern, tableType[:])
println("Query Args", queryArgs)
@@ -180,16 +203,37 @@ func TestPrepareTablesSQLWithNoTableTypeFilter(t *testing.T) {
tableNamePattern := "ADBC-TABLE"
tableType := make([]string, 0)
- expected := `SELECT table_catalog, table_schema, table_name, table_type
- FROM
- (
- SELECT * FROM "DEMO_DB".INFORMATION_SCHEMA.TABLES
- UNION ALL
- SELECT * FROM "DEMOADB".INFORMATION_SCHEMA.TABLES
- UNION ALL
- SELECT * FROM "DEMO'DB".INFORMATION_SCHEMA.TABLES
- )
- WHERE TABLE_CATALOG ILIKE ? AND TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? `
+ expected := `SELECT table_catalog, table_schema, table_name, table_type, constraint_name, constraint_type FROM (SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMO_DB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMO_DB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ ) UNION ALL SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMOADB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMOADB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ ) UNION ALL SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMO'DB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMO'DB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ )) WHERE TABLE_CATALOG ILIKE ? AND TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? `
actual, queryArgs := prepareTablesSQL(catalogNames[:], &catalogPattern, &schemaPattern, &tableNamePattern, tableType[:])
stringqueryArgs := make([]string, len(queryArgs)) // Pre-allocate the right size
@@ -208,16 +252,37 @@ func TestPrepareTablesSQL(t *testing.T) {
tableNamePattern := "ADBC-TABLE"
tableType := [2]string{"BASE TABLE", "VIEW"}
- expected := `SELECT table_catalog, table_schema, table_name, table_type
- FROM
- (
- SELECT * FROM "DEMO_DB".INFORMATION_SCHEMA.TABLES
- UNION ALL
- SELECT * FROM "DEMOADB".INFORMATION_SCHEMA.TABLES
- UNION ALL
- SELECT * FROM "DEMO'DB".INFORMATION_SCHEMA.TABLES
- )
- WHERE TABLE_CATALOG ILIKE ? AND TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? AND TABLE_TYPE IN ('BASE TABLE','VIEW')`
+ expected := `SELECT table_catalog, table_schema, table_name, table_type, constraint_name, constraint_type FROM (SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMO_DB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMO_DB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ ) UNION ALL SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMOADB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMOADB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ ) UNION ALL SELECT T.table_catalog, T.table_schema, T.table_name, T.table_type, TC.constraint_name, TC.constraint_type
+ FROM
+ (
+ "DEMO'DB".INFORMATION_SCHEMA.TABLES AS T
+ LEFT JOIN
+ "DEMO'DB".INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS TC
+ ON
+ T.table_catalog = TC.table_catalog
+ AND T.table_schema = TC.table_schema
+ AND t.table_name = TC.table_name
+ )) WHERE TABLE_CATALOG ILIKE ? AND TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? AND TABLE_TYPE IN ('BASE TABLE','VIEW')`
actual, queryArgs := prepareTablesSQL(catalogNames[:], &catalogPattern, &schemaPattern, &tableNamePattern, tableType[:])
stringqueryArgs := make([]string, len(queryArgs)) // Pre-allocate the right size
@@ -237,28 +302,16 @@ func TestPrepareColumnsSQLNoFilter(t *testing.T) {
columnNamePattern := ""
tableType := make([]string, 0)
- expected := `SELECT table_type, table_catalog, table_schema, table_name, column_name,
+ expected := `SELECT table_catalog, table_schema, table_name, column_name,
ordinal_position,
is_nullable::boolean, data_type, numeric_precision,
numeric_precision_radix,
numeric_scale, is_identity::boolean,
- identity_generation, identity_increment,
- character_maximum_length, character_octet_length, datetime_precision, comment
+ identity_generation, identity_increment, character_maximum_length,
+ character_octet_length, datetime_precision, comment
FROM
(
- SELECT T.table_type, C.*
- FROM
- "DEMO_DB".INFORMATION_SCHEMA.TABLES AS T
- JOIN
- "DEMO_DB".INFORMATION_SCHEMA.COLUMNS AS C
- ON
- T.table_catalog = C.table_catalog AND T.table_schema = C.table_schema AND t.table_name = C.table_name
+ SELECT * FROM "DEMO_DB".INFORMATION_SCHEMA.COLUMNS
UNION ALL
- SELECT T.table_type, C.*
- FROM
- "DEMOADB".INFORMATION_SCHEMA.TABLES AS T
- JOIN
- "DEMOADB".INFORMATION_SCHEMA.COLUMNS AS C
- ON
- T.table_catalog = C.table_catalog AND T.table_schema = C.table_schema AND t.table_name = C.table_name
+ SELECT * FROM "DEMOADB".INFORMATION_SCHEMA.COLUMNS
)
ORDER BY table_catalog, table_schema, table_name, ordinal_position`
actual, queryArgs := prepareColumnsSQL(catalogNames[:], &catalogPattern, &schemaPattern, &tableNamePattern, &columnNamePattern, tableType[:])
@@ -275,30 +328,18 @@ func TestPrepareColumnsSQL(t *testing.T) {
columnNamePattern := "creationDate"
tableType := [2]string{"BASE TABLE", "VIEW"}
- expected := `SELECT table_type, table_catalog, table_schema, table_name, column_name,
+ expected := `SELECT table_catalog, table_schema, table_name, column_name,
ordinal_position,
is_nullable::boolean, data_type, numeric_precision,
numeric_precision_radix,
numeric_scale, is_identity::boolean,
- identity_generation, identity_increment,
- character_maximum_length, character_octet_length, datetime_precision, comment
+ identity_generation, identity_increment, character_maximum_length,
+ character_octet_length, datetime_precision, comment
FROM
(
- SELECT T.table_type, C.*
- FROM
- "DEMO_DB".INFORMATION_SCHEMA.TABLES AS T
- JOIN
- "DEMO_DB".INFORMATION_SCHEMA.COLUMNS AS C
- ON
- T.table_catalog = C.table_catalog AND T.table_schema = C.table_schema AND t.table_name = C.table_name
+ SELECT * FROM "DEMO_DB".INFORMATION_SCHEMA.COLUMNS
UNION ALL
- SELECT T.table_type, C.*
- FROM
- "DEMOADB".INFORMATION_SCHEMA.TABLES AS T
- JOIN
- "DEMOADB".INFORMATION_SCHEMA.COLUMNS AS C
- ON
- T.table_catalog = C.table_catalog AND T.table_schema = C.table_schema AND t.table_name = C.table_name
+ SELECT * FROM "DEMOADB".INFORMATION_SCHEMA.COLUMNS
)
- WHERE TABLE_CATALOG ILIKE ? AND TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? AND COLUMN_NAME ILIKE ? AND TABLE_TYPE IN ('BASE TABLE','VIEW')
+ WHERE TABLE_CATALOG ILIKE ? AND TABLE_SCHEMA ILIKE ? AND TABLE_NAME ILIKE ? AND COLUMN_NAME ILIKE ?
ORDER BY table_catalog, table_schema, table_name, ordinal_position`
actual, queryArgs := prepareColumnsSQL(catalogNames[:], &catalogPattern, &schemaPattern, &tableNamePattern, &columnNamePattern, tableType[:])