This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 0a0ee3ef71 GH-47786: [C++][FlightRPC] Establish ODBC tests (#47788)
0a0ee3ef71 is described below
commit 0a0ee3ef71c20b2d35c92e72fe0bb878f2f14a56
Author: justing-bq <[email protected]>
AuthorDate: Tue Oct 21 20:09:25 2025 -0700
GH-47786: [C++][FlightRPC] Establish ODBC tests (#47788)
### Rationale for this change
https://github.com/apache/arrow/issues/47786
### What changes are included in this PR?
Created a new directory/subproject for tests.
Added testing utilities (`odbc_test_suite.h`/`odbc_test_suite.cc`) providing base test fixtures and helper functions.
Created a new test file with one simple test.
### Are these changes tested?
Yes
### Are there any user-facing changes?
No
* GitHub Issue: #47786
Authored-by: justing-bq <[email protected]>
Signed-off-by: David Li <[email protected]>
---
cpp/src/arrow/flight/sql/odbc/CMakeLists.txt | 1 +
.../flight/sql/odbc/odbc_impl/encoding_utils.h | 21 +
.../arrow/flight/sql/odbc/odbc_impl/system_dsn.cc | 18 +-
.../arrow/flight/sql/odbc/odbc_impl/system_dsn.h | 68 +++
cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt | 46 ++
.../arrow/flight/sql/odbc/tests/connection_test.cc | 43 ++
.../arrow/flight/sql/odbc/tests/odbc_test_suite.cc | 504 +++++++++++++++++++++
.../arrow/flight/sql/odbc/tests/odbc_test_suite.h | 254 +++++++++++
8 files changed, 948 insertions(+), 7 deletions(-)
diff --git a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt
b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt
index 79f351ae9c..ac18a9bc7c 100644
--- a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt
+++ b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt
@@ -34,6 +34,7 @@ else()
endif()
add_subdirectory(odbc_impl)
+add_subdirectory(tests)
arrow_install_all_headers("arrow/flight/sql/odbc")
diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h
b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h
index b3c1030f40..a5cc3a6f4c 100644
--- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h
+++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h
@@ -36,6 +36,7 @@ namespace ODBC {
using arrow::flight::sql::odbc::DriverException;
using arrow::flight::sql::odbc::GetSqlWCharSize;
using arrow::flight::sql::odbc::Utf8ToWcs;
+using arrow::flight::sql::odbc::WcsToUtf8;
// Return the number of bytes required for the conversion.
template <typename CHAR_TYPE>
@@ -80,4 +81,24 @@ inline size_t ConvertToSqlWChar(const std::string& str,
SQLWCHAR* buffer,
}
}
+/// \brief Convert a buffer of SQLWCHAR to a UTF-8 std::string.
+/// \param[in] wchar_msg SQLWCHAR buffer to convert; may be null.
+/// \param[in] msg_len Number of characters in wchar_msg, or SQL_NTS when the
+///            buffer is null-terminated.
+/// \return The converted text, or an empty string when wchar_msg is null,
+///         msg_len is 0, or the first character is a null character.
+inline std::string SqlWcharToString(SQLWCHAR* wchar_msg, SQLINTEGER msg_len = SQL_NTS) {
+  const bool nothing_to_convert =
+      msg_len == 0 || wchar_msg == nullptr || wchar_msg[0] == 0;
+  if (nothing_to_convert) {
+    return {};
+  }
+
+  // Per-thread scratch buffer, presumably kept so repeated calls can reuse its
+  // storage.
+  thread_local std::vector<uint8_t> utf8_bytes;
+
+  void* raw_chars = static_cast<void*>(wchar_msg);
+  if (msg_len != SQL_NTS) {
+    WcsToUtf8(raw_chars, msg_len, &utf8_bytes);
+  } else {
+    WcsToUtf8(raw_chars, &utf8_bytes);
+  }
+
+  return std::string(utf8_bytes.begin(), utf8_bytes.end());
+}
+
} // namespace ODBC
diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc
b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc
index d10eff2580..75501ac8dd 100644
--- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc
+++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+#include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h"
+
// platform.h includes windows.h, so it needs to be included
// before winuser.h
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
@@ -33,13 +35,13 @@
#include <odbcinst.h>
#include <sstream>
-using arrow::flight::sql::odbc::DriverException;
-using arrow::flight::sql::odbc::FlightSqlConnection;
-using arrow::flight::sql::odbc::config::Configuration;
-using arrow::flight::sql::odbc::config::ConnectionStringParser;
-using arrow::flight::sql::odbc::config::DsnConfigurationWindow;
-using arrow::flight::sql::odbc::config::Result;
-using arrow::flight::sql::odbc::config::Window;
+namespace arrow::flight::sql::odbc {
+
+using config::Configuration;
+using config::ConnectionStringParser;
+using config::DsnConfigurationWindow;
+using config::Result;
+using config::Window;
bool DisplayConnectionWindow(void* window_parent, Configuration& config) {
HWND hwnd_parent = (HWND)window_parent;
@@ -237,3 +239,5 @@ BOOL INSTAPI ConfigDSNW(HWND hwnd_parent, WORD req, LPCWSTR
wdriver,
return TRUE;
}
+
+} // namespace arrow::flight::sql::odbc
diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h
b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h
new file mode 100644
index 0000000000..32d17af675
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+// platform.h includes windows.h, so it needs to be included first
+#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
+
+#include <string>
+
+#include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h"
+
+namespace arrow::flight::sql::odbc {
+
+// NOTE(review): this using-declaration is visible to every file that includes
+// this header.
+using config::Configuration;
+
+#if defined _WIN32
+/**
+ * Display connection window for user to configure connection parameters.
+ *
+ * @param window_parent Parent window handle, passed as an untyped pointer
+ *        (system_dsn.cc casts it to HWND).
+ * @param config Output configuration.
+ * @return True on success and false on fail.
+ */
+bool DisplayConnectionWindow(void* window_parent, Configuration& config);
+
+/**
+ * For SQLDriverConnect.
+ * Display connection window for user to configure connection parameters.
+ *
+ * @param window_parent Parent window handle.
+ * @param config Output configuration. It is presumed to be empty on entry and
+ *        is filled using values from properties.
+ * @param properties Output properties.
+ * @return True on success and false on fail.
+ */
+bool DisplayConnectionWindow(void* window_parent, Configuration& config,
+ Connection::ConnPropertyMap& properties);
+#endif
+
+// NOTE(review): the two functions below are declared on every platform, but the
+// test suite only calls RegisterDsn under _WIN32 (see the TODO there); confirm
+// that a definition exists for each platform that calls them.
+
+/**
+ * Register DSN with specified configuration.
+ *
+ * @param config Configuration of the DSN to register.
+ * @param driver Name of the driver the DSN is registered for.
+ * @return True on success and false on fail.
+ */
+bool RegisterDsn(const Configuration& config, LPCWSTR driver);
+
+/**
+ * Unregister specified DSN.
+ *
+ * @param dsn DSN name.
+ * @return True on success and false on fail.
+ */
+bool UnregisterDsn(const std::wstring& dsn);
+
+} // namespace arrow::flight::sql::odbc
diff --git a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt
b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt
new file mode 100644
index 0000000000..4bc240637e
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Custom target named "tests"; no commands or dependencies are attached here.
+# NOTE(review): CMake target names are global to the whole build; confirm that
+# no other directory also defines a target called "tests".
+add_custom_target(tests)
+
+# The tests call the ODBC API directly, so they need the ODBC headers/libraries.
+find_package(ODBC REQUIRED)
+include_directories(${ODBC_INCLUDE_DIRS})
+
+# SQLite backs the in-process mock Flight SQL server used by the mock fixtures.
+find_package(SQLite3Alt REQUIRED)
+
+# Sources of the example SQLite Flight SQL server, compiled into the test binary
+# so the mock fixtures can start it in-process.
+set(ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS
+ ../../example/sqlite_sql_info.cc
+ ../../example/sqlite_type_info.cc
+ ../../example/sqlite_statement.cc
+ ../../example/sqlite_statement_batch_reader.cc
+ ../../example/sqlite_server.cc
+ ../../example/sqlite_tables_schema_batch_reader.cc)
+
+add_arrow_test(flight_sql_odbc_test
+ SOURCES
+ odbc_test_suite.cc
+ odbc_test_suite.h
+ connection_test.cc
+ # Enable Protobuf cleanup after test execution
+ # GH-46889: move protobuf_test_util to a more common location
+ ../../../../engine/substrait/protobuf_test_util.cc
+ ${ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS}
+ EXTRA_LINK_LIBS
+ ${ODBC_LIBRARIES}
+ # NOTE(review): ODBCINST is not set in this file; presumably a parent
+ # CMakeLists.txt defines it - confirm.
+ ${ODBCINST}
+ ${SQLite3_LIBRARIES}
+ arrow_odbc_spi_impl)
diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc
b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc
new file mode 100644
index 0000000000..fa1ccf2854
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifdef _WIN32
+# include <windows.h>
+#endif
+
+#include <sql.h>
+#include <sqltypes.h>
+#include <sqlucode.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace arrow::flight::sql::odbc {
+
+// Allocating an ODBC environment handle and releasing it again must both
+// succeed.
+TEST(SQLAllocHandle, SQLAllocHandleEnv) {
+  SQLHENV env_handle = nullptr;
+
+  // Allocation must succeed and hand back a non-null handle.
+  ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env_handle));
+  ASSERT_NE(nullptr, env_handle);
+
+  // Release the handle we just obtained.
+  ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env_handle));
+}
+
+}  // namespace arrow::flight::sql::odbc
diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc
b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc
new file mode 100644
index 0000000000..fccb552575
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc
@@ -0,0 +1,504 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// For DSN registration. flight_sql_connection.h needs to be included first due to
+// conflicts with windows.h.
+#include "arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h"
+
+#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h"
+
+// For DSN registration
+#include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h"
+#include "arrow/flight/sql/odbc/odbc_impl/encoding_utils.h"
+#include "arrow/flight/sql/odbc/odbc_impl/odbc_connection.h"
+
+namespace arrow::flight::sql::odbc {
+
+/// Allocate the environment handle `env`, set its ODBC version to `odbc_ver`,
+/// and allocate the connection handle `conn` on it. Any failure is a fatal
+/// gtest failure annotated with the environment's ODBC diagnostic record.
+void FlightSQLODBCRemoteTestBase::AllocEnvConnHandles(SQLINTEGER odbc_ver) {
+  // Allocate an environment handle
+  ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env));
+
+  // Integer attributes are passed by value through the SQLPOINTER argument.
+  ASSERT_EQ(SQL_SUCCESS,
+            SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION,
+                          reinterpret_cast<SQLPOINTER>(static_cast<intptr_t>(odbc_ver)),
+                          0))
+      << GetOdbcErrorMessage(SQL_HANDLE_ENV, env);
+
+  // Allocate a connection using alloc handle
+  ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env, &conn))
+      << GetOdbcErrorMessage(SQL_HANDLE_ENV, env);
+}
+
+/// Allocate the ODBC handles for version `odbc_ver`, then connect using the
+/// fixture's connection string. Stops at the first fatal failure.
+void FlightSQLODBCRemoteTestBase::Connect(SQLINTEGER odbc_ver) {
+  ASSERT_NO_FATAL_FAILURE(AllocEnvConnHandles(odbc_ver));
+  ASSERT_NO_FATAL_FAILURE(ConnectWithString(GetConnectionString()));
+}
+
+/// Connect `conn` with SQLDriverConnect using `connect_str`, then allocate the
+/// statement handle `stmt` on the connection.
+///
+/// \param connect_str UTF-8 encoded ODBC connection string.
+void FlightSQLODBCRemoteTestBase::ConnectWithString(std::string connect_str) {
+  // Decode the UTF-8 connection string into wide characters. Copying the bytes
+  // one by one into SQLWCHAR would sign-extend bytes >= 0x80 and split multi-byte
+  // sequences, corrupting any non-ASCII attribute value.
+  // NOTE: where wchar_t is wider than SQLWCHAR, code points outside the BMP are
+  // still not represented correctly.
+  ASSERT_OK_AND_ASSIGN(std::wstring wide_connect_str,
+                       arrow::util::UTF8ToWideString(connect_str));
+  std::vector<SQLWCHAR> connect_str0(wide_connect_str.begin(),
+                                     wide_connect_str.end());
+
+  SQLWCHAR out_str[kOdbcBufferSize];
+  SQLSMALLINT out_str_len;
+
+  // Connecting to ODBC server.
+  ASSERT_EQ(SQL_SUCCESS,
+            SQLDriverConnect(conn, nullptr, connect_str0.data(),
+                             static_cast<SQLSMALLINT>(connect_str0.size()), out_str,
+                             kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT))
+      << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn);
+
+  // Allocate a statement using alloc handle
+  ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt));
+}
+
+/// Release the statement, connection and environment handles in reverse order
+/// of allocation. Non-fatal expectations are used so that every step is still
+/// attempted when an earlier one fails.
+void FlightSQLODBCRemoteTestBase::Disconnect() {
+  // The statement belongs to the connection, so it goes first.
+  EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt));
+
+  // Report the driver's diagnostic if the disconnect is refused.
+  const SQLRETURN disconnect_ret = SQLDisconnect(conn);
+  EXPECT_EQ(SQL_SUCCESS, disconnect_ret) << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn);
+
+  // Finally the connection handle and the environment that owns it.
+  EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn));
+  EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env));
+}
+
+/// Return the remote-server connection string stored in the environment
+/// variable named by kTestConnectStr. Fails hard (ValueOrDie) when it is unset.
+std::string FlightSQLODBCRemoteTestBase::GetConnectionString() {
+  return arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOrDie();
+}
+
+/// Return the remote connection string with a non-existent user id appended,
+/// so tests can exercise connecting with invalid credentials.
+std::string FlightSQLODBCRemoteTestBase::GetInvalidConnectionString() {
+  std::string connect_str = GetConnectionString();
+  // The string comes from the environment and need not end with ';'. Without a
+  // separator the appended attribute would be glued onto the value of the last
+  // attribute instead of being parsed as its own key.
+  if (!connect_str.empty() && connect_str.back() != ';') {
+    connect_str += ';';
+  }
+  // Append invalid uid to connection string
+  connect_str += std::string("uid=non_existent_id;");
+  return connect_str;
+}
+
+/// Return a single-row SELECT for the remote server that yields one column per
+/// tested value: numeric min/max values, booleans, character values and
+/// date/timestamp bounds. The unsigned big int maximum is returned as a string
+/// literal, as the embedded SQL comment explains.
+std::wstring FlightSQLODBCRemoteTestBase::GetQueryAllDataTypes() {
+ std::wstring wsql =
+ LR"( SELECT
+ -- Numeric types
+ -128 as stiny_int_min, 127 as stiny_int_max,
+ 0 as utiny_int_min, 255 as utiny_int_max,
+
+ -32768 as ssmall_int_min, 32767 as ssmall_int_max,
+ 0 as usmall_int_min, 65535 as usmall_int_max,
+
+ CAST(-2147483648 AS INTEGER) AS sinteger_min,
+ CAST(2147483647 AS INTEGER) AS sinteger_max,
+ CAST(0 AS BIGINT) AS uinteger_min,
+ CAST(4294967295 AS BIGINT) AS uinteger_max,
+
+ CAST(-9223372036854775808 AS BIGINT) AS sbigint_min,
+ CAST(9223372036854775807 AS BIGINT) AS sbigint_max,
+ CAST(0 AS BIGINT) AS ubigint_min,
+ --Use string to represent unsigned big int due to lack of support
from
+ --remote test server
+ '18446744073709551615' AS ubigint_max,
+
+ CAST(-999999999 AS DECIMAL(38, 0)) AS decimal_negative,
+ CAST(999999999 AS DECIMAL(38, 0)) AS decimal_positive,
+
+ CAST(-3.40282347E38 AS FLOAT) AS float_min, CAST(3.40282347E38 AS
FLOAT) AS float_max,
+
+ CAST(-1.7976931348623157E308 AS DOUBLE) AS double_min,
+ CAST(1.7976931348623157E308 AS DOUBLE) AS double_max,
+
+ --Boolean
+ CAST(false AS BOOLEAN) AS bit_false,
+ CAST(true AS BOOLEAN) AS bit_true,
+
+ --Character types
+ 'Z' AS c_char, '你' AS c_wchar,
+
+ '你好' AS c_wvarchar,
+
+ 'XYZ' AS c_varchar,
+
+ --Date / timestamp
+ CAST(DATE '1400-01-01' AS DATE) AS date_min,
+ CAST(DATE '9999-12-31' AS DATE) AS date_max,
+
+ CAST(TIMESTAMP '1400-01-01 00:00:00' AS TIMESTAMP) AS timestamp_min,
+ CAST(TIMESTAMP '9999-12-31 23:59:59' AS TIMESTAMP) AS timestamp_max;
+ )";
+ return wsql;
+}
+
+/// Skip the test unless the ARROW_FLIGHT_SQL_ODBC_CONN environment variable
+/// (kTestConnectStr) is set; otherwise connect to the server it describes.
+void FlightSQLODBCRemoteTestBase::SetUp() {
+  if (arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOr("").empty()) {
+    // Name the actual environment variable rather than the C++ constant.
+    GTEST_SKIP() << "Skipping test: " << kTestConnectStr << " not set";
+  }
+
+  // Return on a fatal failure inside Connect(); otherwise connected_ would be set
+  // and TearDown() would try to free handles that were never allocated.
+  ASSERT_NO_FATAL_FAILURE(this->Connect());
+  connected_ = true;
+}
+
+/// Release the ODBC handles, but only if SetUp() finished connecting.
+void FlightSQLODBCRemoteTestBase::TearDown() {
+  if (!connected_) {
+    return;
+  }
+  this->Disconnect();
+  connected_ = false;
+}
+
+/// Skip the test unless the ARROW_FLIGHT_SQL_ODBC_CONN environment variable
+/// (kTestConnectStr) is set; otherwise connect with SQL_ATTR_ODBC_VERSION set to
+/// SQL_OV_ODBC2.
+void FlightSQLOdbcV2RemoteTestBase::SetUp() {
+  if (arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOr("").empty()) {
+    // Name the actual environment variable rather than the C++ constant.
+    GTEST_SKIP() << "Skipping test: " << kTestConnectStr << " not set";
+  }
+
+  // Return on a fatal failure inside Connect() so connected_ is only set once
+  // every handle has actually been allocated.
+  ASSERT_NO_FATAL_FAILURE(this->Connect(SQL_OV_ODBC2));
+  connected_ = true;
+}
+
+/// Extract the bearer token from the authorization header in `incoming_headers`.
+///
+/// \return The text following kBearerPrefix (matched case-insensitively), or an
+///         empty string when the header is missing, does not start with the
+///         prefix, or has nothing after it.
+std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers) {
+  // Compare characters without case sensitivity. ::toupper requires a value
+  // representable as unsigned char (or EOF). Passing a negative plain char,
+  // e.g. a byte >= 0x80 where char is signed, is undefined behavior, so cast
+  // first.
+  auto char_compare = [](char char1, char char2) {
+    return ::toupper(static_cast<unsigned char>(char1)) ==
+           ::toupper(static_cast<unsigned char>(char2));
+  };
+
+  std::string bearer_token;
+  auto auth_header = incoming_headers.find(kAuthorizationHeader);
+  if (auth_header != incoming_headers.end()) {
+    const std::string auth_val(auth_header->second);
+    if (auth_val.size() > kBearerPrefix.length() &&
+        std::equal(auth_val.begin(), auth_val.begin() + kBearerPrefix.length(),
+                   kBearerPrefix.begin(), char_compare)) {
+      bearer_token = auth_val.substr(kBearerPrefix.length());
+    }
+  }
+  return bearer_token;
+}
+
+/// Record in *is_valid_ whether the incoming call carried the expected test
+/// bearer token. The outgoing headers are not modified.
+void MockServerMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) {
+  const std::string token = FindTokenInCallHeaders(incoming_headers_);
+  *is_valid_ = token == std::string(kTestToken);
+}
+
+/// Reject calls that do not carry the expected test bearer token as
+/// Unauthenticated. For accepted calls, install a MockServerMiddleware.
+Status MockServerMiddlewareFactory::StartCall(
+    const CallInfo& info, const ServerCallContext& context,
+    std::shared_ptr<ServerMiddleware>* middleware) {
+  const std::string token = FindTokenInCallHeaders(context.incoming_headers());
+  if (token != std::string(kTestToken)) {
+    return MakeFlightError(FlightStatusCode::Unauthenticated,
+                           "Invalid token for mock server");
+  }
+
+  *middleware =
+      std::make_shared<MockServerMiddleware>(context.incoming_headers(), &is_valid_);
+  return Status::OK();
+}
+
+/// Build the connection string for the mock server on localhost:`port`. It
+/// authenticates with the test bearer token and disables encryption.
+std::string FlightSQLODBCMockTestBase::GetConnectionString() {
+  std::string connect_str =
+      "driver={Apache Arrow Flight SQL ODBC Driver};HOST=localhost;port=";
+  connect_str += std::to_string(port);
+  connect_str += ";token=";
+  connect_str += std::string(kTestToken);
+  connect_str += ";useEncryption=false;";
+  return connect_str;
+}
+
+/// Return the mock connection string followed by an additional `token`
+/// attribute holding an invalid value.
+std::string FlightSQLODBCMockTestBase::GetInvalidConnectionString() {
+  // The mock connection string built by this class ends with ';', so the extra
+  // attribute can be appended directly.
+  return GetConnectionString() + std::string("token=invalid_token;");
+}
+
+/// Return the mock-server (SQLite) version of the all-data-types query: one
+/// row with one column per tested value. SQLite types are used throughout
+/// (INTEGER, NUMERIC, REAL). Booleans are 0/1, the unsigned big int maximum is
+/// a text literal, and dates come from DATE()/DATETIME().
+std::wstring FlightSQLODBCMockTestBase::GetQueryAllDataTypes() {
+ std::wstring wsql =
+ LR"( SELECT
+ -- Numeric types
+ -128 AS stiny_int_min, 127 AS stiny_int_max,
+ 0 AS utiny_int_min, 255 AS utiny_int_max,
+
+ -32768 AS ssmall_int_min, 32767 AS ssmall_int_max,
+ 0 AS usmall_int_min, 65535 AS usmall_int_max,
+
+ CAST(-2147483648 AS INTEGER) AS sinteger_min,
+ CAST(2147483647 AS INTEGER) AS sinteger_max,
+ CAST(0 AS INTEGER) AS uinteger_min,
+ CAST(4294967295 AS INTEGER) AS uinteger_max,
+
+ CAST(-9223372036854775808 AS INTEGER) AS sbigint_min,
+ CAST(9223372036854775807 AS INTEGER) AS sbigint_max,
+ CAST(0 AS INTEGER) AS ubigint_min,
+ -- stored as TEXT as SQLite doesn't support unsigned big int
+ '18446744073709551615' AS ubigint_max,
+
+ CAST('-999999999' AS NUMERIC) AS decimal_negative,
+ CAST('999999999' AS NUMERIC) AS decimal_positive,
+
+ CAST(-3.40282347E38 AS REAL) AS float_min,
+ CAST(3.40282347E38 AS REAL) AS float_max,
+
+ CAST(-1.7976931348623157E308 AS REAL) AS double_min,
+ CAST(1.7976931348623157E308 AS REAL) AS double_max,
+
+ -- Boolean
+ 0 AS bit_false,
+ 1 AS bit_true,
+
+ -- Character types
+ 'Z' AS c_char,
+ '你' AS c_wchar,
+ '你好' AS c_wvarchar,
+ 'XYZ' AS c_varchar,
+
+ DATE('1400-01-01') AS date_min,
+ DATE('9999-12-31') AS date_max,
+
+ DATETIME('1400-01-01 00:00:00') AS timestamp_min,
+ DATETIME('9999-12-31 23:59:59') AS timestamp_max;
+ )";
+ return wsql;
+}
+
+/// Create TestTable on the mock server with an auto-increment id and three rows:
+/// ('One', 1), ('Two', 0) and ('Three', -1).
+void FlightSQLODBCMockTestBase::CreateTestTables() {
+ ASSERT_OK(server_->ExecuteSql(R"(
+ CREATE TABLE TestTable (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ keyName varchar(100),
+ value int);
+
+ INSERT INTO TestTable (keyName, value) VALUES ('One', 1);
+ INSERT INTO TestTable (keyName, value) VALUES ('Two', 0);
+ INSERT INTO TestTable (keyName, value) VALUES ('Three', -1);
+ )"));
+}
+
+/// Create AllTypesTable on the mock server with one row. The row holds an
+/// auto-increment integer key, a varchar, a BLOB (X'31737420726F77' are the
+/// ASCII bytes of "1st row") and a REAL.
+void FlightSQLODBCMockTestBase::CreateTableAllDataType() {
+ // Limitation on mock SQLite server:
+ // Only int64, float64, binary, and utf8 Arrow Types are supported by
+ // SQLiteFlightSqlServer::Impl::DoGetTables
+ ASSERT_OK(server_->ExecuteSql(R"(
+ CREATE TABLE AllTypesTable(
+ bigint_col INTEGER PRIMARY KEY AUTOINCREMENT,
+ char_col varchar(100),
+ varbinary_col BLOB,
+ double_col REAL);
+
+ INSERT INTO AllTypesTable (
+ char_col,
+ varbinary_col,
+ double_col) VALUES (
+ '1st Row',
+ X'31737420726F77',
+ 3.14159
+ );
+ )"));
+}
+
+/// Create table 数据 with a single varchar column 资料 and insert three rows, so
+/// tests can exercise non-ASCII identifiers and values on the mock server.
+void FlightSQLODBCMockTestBase::CreateUnicodeTable() {
+  const std::wstring unicode_sql_w = LR"(
+ CREATE TABLE 数据(
+ 资料 varchar(100));
+
+ INSERT INTO 数据 (资料) VALUES ('第一行');
+ INSERT INTO 数据 (资料) VALUES ('二行');
+ INSERT INTO 数据 (资料) VALUES ('3rd Row');
+ )";
+  // Fail the test if the wide-to-UTF-8 conversion fails, instead of silently
+  // executing an empty statement and creating no table.
+  ASSERT_OK_AND_ASSIGN(std::string unicode_sql,
+                       arrow::util::WideStringToUTF8(unicode_sql_w));
+  ASSERT_OK(server_->ExecuteSql(unicode_sql));
+}
+
+/// Start the in-process SQLite Flight SQL server with the bearer-token
+/// middleware installed, store the port it listens on in `port`, and open a
+/// Flight client connection to it.
+void FlightSQLODBCMockTestBase::Initialize() {
+  // Port 0: the actual port is read back from the server after Init().
+  ASSERT_OK_AND_ASSIGN(auto bind_location, Location::ForGrpcTcp("0.0.0.0", 0));
+  arrow::flight::FlightServerOptions server_options(bind_location);
+  server_options.auth_handler = std::make_unique<NoOpAuthHandler>();
+  server_options.middleware.push_back(
+      {"bearer-auth-server", std::make_shared<MockServerMiddlewareFactory>()});
+
+  ASSERT_OK_AND_ASSIGN(server_,
+                       arrow::flight::sql::example::SQLiteFlightSqlServer::Create());
+  ASSERT_OK(server_->Init(server_options));
+  port = server_->port();
+
+  // Connect a Flight client to the freshly started server; the client itself is
+  // not used afterwards.
+  ASSERT_OK_AND_ASSIGN(auto connect_location, Location::ForGrpcTcp("localhost", port));
+  ASSERT_OK_AND_ASSIGN(auto probe_client,
+                       arrow::flight::FlightClient::Connect(connect_location));
+}
+
+/// Start the mock server and connect to it with the default ODBC version.
+void FlightSQLODBCMockTestBase::SetUp() {
+  // Return on a fatal failure in either step. Connecting to a server that failed
+  // to start is pointless, and connected_ must only be set once every handle
+  // has actually been allocated.
+  ASSERT_NO_FATAL_FAILURE(this->Initialize());
+  ASSERT_NO_FATAL_FAILURE(this->Connect());
+  connected_ = true;
+}
+
+/// Release the ODBC handles if SetUp() connected, then shut down the mock
+/// server if one was created.
+void FlightSQLODBCMockTestBase::TearDown() {
+  if (connected_) {
+    this->Disconnect();
+    connected_ = false;
+  }
+  // server_ stays null when Initialize() failed before creating the server;
+  // dereferencing it here would crash the test binary.
+  if (server_) {
+    ASSERT_OK(server_->Shutdown());
+  }
+}
+
+/// Start the mock server and connect to it with SQL_ATTR_ODBC_VERSION set to
+/// SQL_OV_ODBC2.
+void FlightSQLOdbcV2MockTestBase::SetUp() {
+  // Return on a fatal failure in either step so connected_ is only set once the
+  // server is running and every ODBC handle has been allocated.
+  ASSERT_NO_FATAL_FAILURE(this->Initialize());
+  ASSERT_NO_FATAL_FAILURE(this->Connect(SQL_OV_ODBC2));
+  connected_ = true;
+}
+
+/// Return true when both maps contain the same keys mapped to equal values.
+///
+/// Keys are looked up with find() rather than operator[]. operator[] inserts a
+/// default-constructed value for a key missing from map2, so two maps with
+/// different keys used to compare equal whenever map1's value was empty.
+bool CompareConnPropertyMap(Connection::ConnPropertyMap map1,
+                            Connection::ConnPropertyMap map2) {
+  if (map1.size() != map2.size()) return false;
+
+  for (const auto& [key, value] : map1) {
+    auto match = map2.find(key);
+    if (match == map2.end() || match->second != value) return false;
+  }
+
+  return true;
+}
+
+/// Read the first diagnostic record of `handle` and expect its SQLSTATE to be
+/// `expected_state`.
+void VerifyOdbcErrorState(SQLSMALLINT handle_type, SQLHANDLE handle,
+                          std::string_view expected_state) {
+  // Room for the 5-character SQLSTATE plus its null terminator.
+  SQLWCHAR state_buf[7] = {};
+  SQLINTEGER native_error;
+  SQLWCHAR message_buf[kOdbcBufferSize] = {};
+  SQLSMALLINT message_len = 0;
+
+  // The reported message length is not used: it is counted in bytes on Windows
+  // but in characters on Linux.
+  SQLGetDiagRec(handle_type, handle, 1, state_buf, &native_error, message_buf,
+                kOdbcBufferSize, &message_len);
+
+  EXPECT_EQ(expected_state, ODBC::SqlWcharToString(state_buf));
+}
+
+/// Format the first diagnostic record of `handle` as "<SQLSTATE>: <message>".
+/// Returns "Cannot find ODBC error message" when either part is empty.
+std::string GetOdbcErrorMessage(SQLSMALLINT handle_type, SQLHANDLE handle) {
+  // Room for the 5-character SQLSTATE plus its null terminator.
+  SQLWCHAR state_buf[7] = {};
+  SQLINTEGER native_error;
+  SQLWCHAR message_buf[kOdbcBufferSize] = {};
+  SQLSMALLINT message_len = 0;
+
+  // The reported message length is not used: it is counted in bytes on Windows
+  // but in characters on Linux.
+  SQLGetDiagRec(handle_type, handle, 1, state_buf, &native_error, message_buf,
+                kOdbcBufferSize, &message_len);
+
+  const std::string state = ODBC::SqlWcharToString(state_buf);
+  if (state.empty() || message_buf[0] == 0) {
+    return "Cannot find ODBC error message";
+  }
+  return state + ": " + ODBC::SqlWcharToString(message_buf);
+}
+
+// TODO: re-enable the following on macOS and Linux once RegisterDsn is
+// implemented there.
+#if defined _WIN32
+/// Register the test DSN (kTestDsn) with the attributes parsed from
+/// `connection_str`.
+bool WriteDSN(std::string connection_str) {
+  Connection::ConnPropertyMap parsed_properties;
+  ODBC::ODBCConnection::GetPropertiesFromConnString(connection_str, parsed_properties);
+  return WriteDSN(parsed_properties);
+}
+
+/// Register the test DSN (kTestDsn) with the given attributes. The DRIVER
+/// attribute selects the driver the DSN is registered for.
+///
+/// \return The result of RegisterDsn().
+bool WriteDSN(Connection::ConnPropertyMap properties) {
+  using config::Configuration;
+
+  Configuration dsn_config;
+  dsn_config.Set(FlightSqlConnection::DSN, std::string(kTestDsn));
+  for (const auto& [attr_name, attr_value] : properties) {
+    dsn_config.Set(attr_name, attr_value);
+  }
+
+  const std::string driver_name = dsn_config.Get(FlightSqlConnection::DRIVER);
+  const std::wstring w_driver =
+      arrow::util::UTF8ToWideString(driver_name).ValueOr(L"");
+  return RegisterDsn(dsn_config, w_driver.c_str());
+}
+#endif
+
+/// Convert a wide-character buffer filled by an ODBC call into std::wstring.
+///
+/// \param str_val Buffer written by the driver.
+/// \param str_len Length reported by the driver, in bytes. When it is 0, the
+///                buffer is instead read up to its first null character.
+std::wstring ConvertToWString(const std::vector<SQLWCHAR>& str_val,
+                              SQLSMALLINT str_len) {
+  if (str_len == 0) {
+    // Stop at the first null character, or at the end of the buffer when there
+    // is none, instead of reading past it. Copying element-wise also works where
+    // SQLWCHAR is not wchar_t, unlike std::wstring(const wchar_t*).
+    auto terminator = std::find(str_val.begin(), str_val.end(), SQLWCHAR{0});
+    return std::wstring(str_val.begin(), terminator);
+  }
+
+  EXPECT_GT(str_len, 0);
+  EXPECT_LE(str_len, static_cast<SQLSMALLINT>(kOdbcBufferSize));
+  // The EXPECTs above do not stop execution, so bound the character count by
+  // the buffer size to avoid forming an iterator outside str_val.
+  size_t char_count = 0;
+  if (str_len > 0) {
+    char_count = std::min<size_t>(
+        static_cast<size_t>(str_len) / ODBC::GetSqlWCharSize(), str_val.size());
+  }
+  return std::wstring(str_val.begin(), str_val.begin() + char_count);
+}
+
+/// Fetch column `col_id` of the current row as SQL_C_WCHAR and expect it to
+/// equal `expected`.
+void CheckStringColumnW(SQLHSTMT stmt, int col_id, const std::wstring& expected) {
+  SQLWCHAR buf[1024];
+  // SQLGetData's BufferLength is in bytes, and sizeof(buf) is already the size in
+  // bytes. Multiplying by the SQLWCHAR size again overstated the buffer and let
+  // the driver write past its end.
+  SQLLEN buf_len = sizeof(buf);
+
+  ASSERT_EQ(SQL_SUCCESS, SQLGetData(stmt, col_id, SQL_C_WCHAR, buf, buf_len, &buf_len));
+
+  // A negative indicator (e.g. SQL_NULL_DATA) must stop the check here;
+  // converting it to size_t below would produce a huge range.
+  ASSERT_GT(buf_len, 0);
+
+  // returned buf_len is in bytes so convert to length in characters
+  size_t char_count = static_cast<size_t>(buf_len) / ODBC::GetSqlWCharSize();
+  std::wstring returned(buf, buf + char_count);
+
+  EXPECT_EQ(expected, returned);
+}
+
+/// Fetch column `col_id` of the current row as SQL_C_WCHAR and expect it to be
+/// SQL NULL.
+void CheckNullColumnW(SQLHSTMT stmt, int col_id) {
+  SQLWCHAR data_buf[1024];
+  SQLLEN indicator = sizeof(data_buf);
+
+  ASSERT_EQ(SQL_SUCCESS,
+            SQLGetData(stmt, col_id, SQL_C_WCHAR, data_buf, indicator, &indicator));
+  EXPECT_EQ(SQL_NULL_DATA, indicator);
+}
+
+/// Fetch column `col_id` of the current row as SQL_C_LONG and expect it to
+/// equal `expected`.
+void CheckIntColumn(SQLHSTMT stmt, int col_id, const SQLINTEGER& expected) {
+  SQLINTEGER value;
+  SQLLEN value_len = sizeof(value);
+
+  ASSERT_EQ(SQL_SUCCESS,
+            SQLGetData(stmt, col_id, SQL_C_LONG, &value, sizeof(value), &value_len));
+  EXPECT_EQ(expected, value);
+}
+
+/// Fetch column `col_id` of the current row as SQL_C_SSHORT and expect it to
+/// equal `expected`.
+void CheckSmallIntColumn(SQLHSTMT stmt, int col_id, const SQLSMALLINT& expected) {
+  SQLSMALLINT value;
+  SQLLEN value_len = sizeof(value);
+
+  ASSERT_EQ(SQL_SUCCESS,
+            SQLGetData(stmt, col_id, SQL_C_SSHORT, &value, sizeof(value), &value_len));
+  EXPECT_EQ(expected, value);
+}
+
+/// Call SQLFetch once on `stmt` and assert that it returns `expected_return`.
+void ValidateFetch(SQLHSTMT stmt, SQLRETURN expected_return) {
+  const SQLRETURN fetch_ret = SQLFetch(stmt);
+  ASSERT_EQ(expected_return, fetch_ret);
+}
+
+} // namespace arrow::flight::sql::odbc
diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h
b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h
new file mode 100644
index 0000000000..e35e6c38f8
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h
@@ -0,0 +1,254 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/utf8.h"
+
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/sql/client.h"
+#include "arrow/flight/sql/example/sqlite_server.h"
+#include "arrow/flight/sql/odbc/odbc_impl/encoding_utils.h"
+#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
+
+#include <sql.h>
+#include <sqltypes.h>
+#include <sqlucode.h>
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/flight/sql/odbc/odbc_impl/odbc_connection.h"
+
+// For DSN registration
+#include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h"
+
+// Name of the environment variable that holds the connection string for the
+// remote-server tests. (A const namespace-scope variable already has internal
+// linkage, so no `static` is needed.)
+constexpr std::string_view kTestConnectStr = "ARROW_FLIGHT_SQL_ODBC_CONN";
+// Data source name used by the test suite.
+constexpr std::string_view kTestDsn = "Apache Arrow Flight SQL Test DSN";
+
+namespace arrow::flight::sql::odbc {
+
+/// \brief Base test fixture for running tests against a remote server.
+///
+/// Each test file running remote server tests should define a fixture that
+/// inherits from this one. The connection string used to reach the server is
+/// taken from the ARROW_FLIGHT_SQL_ODBC_CONN environment variable.
+class FlightSQLODBCRemoteTestBase : public ::testing::Test {
+ public:
+  /// \brief Allocate the environment and connection handles.
+  void AllocEnvConnHandles(SQLINTEGER odbc_ver = SQL_OV_ODBC3);
+
+  /// \brief Connect to the Arrow Flight SQL server using the connection string
+  /// from the "ARROW_FLIGHT_SQL_ODBC_CONN" environment variable and allocate a
+  /// statement handle. ODBC version 3 is used unless `odbc_ver` says otherwise.
+  void Connect(SQLINTEGER odbc_ver = SQL_OV_ODBC3);
+
+  /// \brief Connect to the Arrow Flight SQL server using `connection_str`.
+  void ConnectWithString(std::string connection_str);
+
+  /// \brief Disconnect from the server.
+  void Disconnect();
+
+  /// \brief Get the connection string from the "ARROW_FLIGHT_SQL_ODBC_CONN"
+  /// environment variable.
+  virtual std::string GetConnectionString();
+
+  /// \brief Get an invalid connection string derived from the one in the
+  /// "ARROW_FLIGHT_SQL_ODBC_CONN" environment variable.
+  virtual std::string GetInvalidConnectionString();
+
+  /// \brief Return a SQL query that selects all data types.
+  virtual std::wstring GetQueryAllDataTypes();
+
+  /// ODBC environment handle.
+  SQLHENV env{};
+
+  /// ODBC connection handle.
+  SQLHDBC conn{};
+
+  /// ODBC statement handle.
+  SQLHSTMT stmt{};
+
+ protected:
+  void SetUp() override;
+  void TearDown() override;
+
+  bool connected_{false};
+};
+
+/// \brief Base test fixture for running ODBC V2 tests against a remote server.
+///
+/// Each test file running remote server ODBC V2 tests should define a fixture
+/// that inherits from this one.
+class FlightSQLOdbcV2RemoteTestBase : public FlightSQLODBCRemoteTestBase {
+ protected:
+  void SetUp() override;
+};
+
+static constexpr std::string_view kAuthorizationHeader = "authorization";
+static constexpr std::string_view kBearerPrefix = "Bearer ";
+static constexpr std::string_view kTestToken = "t0k3n";
+
+/// \brief Look up the token carried in `incoming_headers`.
+/// NOTE(review): presumably reads kAuthorizationHeader and strips kBearerPrefix;
+/// confirm against the definition.
+std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers);
+
+// A server middleware for validating incoming bearer header authentication.
+class MockServerMiddleware : public ServerMiddleware {
+ public:
+  /// \param[in] incoming_headers headers of the incoming call; a copy is kept.
+  /// \param[in] is_valid non-owning pointer to a validity flag (presumably the
+  ///            creating factory's `is_valid_`); it must outlive this object.
+  explicit MockServerMiddleware(const CallHeaders& incoming_headers, bool* is_valid)
+      : incoming_headers_(incoming_headers), is_valid_(is_valid) {}
+
+  void SendingHeaders(AddCallHeaders* outgoing_headers) override;
+
+  void CallCompleted(const Status& status) override {}
+
+  std::string name() const override { return "MockServerMiddleware"; }
+
+ private:
+  CallHeaders incoming_headers_;
+  bool* is_valid_;
+};
+
+// Middleware factory used for bearer-token header authentication testing.
+class MockServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+  MockServerMiddlewareFactory() = default;
+
+  Status StartCall(const CallInfo& info, const ServerCallContext& context,
+                   std::shared_ptr<ServerMiddleware>* middleware) override;
+
+ private:
+  bool is_valid_ = false;
+};
+
+/// \brief Base test fixture for running tests against a mock server.
+///
+/// A mock server is set up for each test case. Each test file running mock
+/// server tests should define a fixture that inherits from this one.
+class FlightSQLODBCMockTestBase : public FlightSQLODBCRemoteTestBase {
+ public:
+  /// \brief Get connection string for mock server
+  std::string GetConnectionString() override;
+  /// \brief Get invalid connection string for mock server
+  std::string GetInvalidConnectionString() override;
+  /// \brief Return a SQL query that selects all data types
+  std::wstring GetQueryAllDataTypes() override;
+
+  /// \brief Run a SQL query to create default table for table test cases
+  void CreateTestTables();
+
+  /// \brief Run a SQL query to create a table with all data types
+  void CreateTableAllDataType();
+  /// \brief Run a SQL query to create a table with a Unicode name
+  void CreateUnicodeTable();
+
+  /// Port of the mock server. Initialised to 0 so it is never read
+  /// uninitialised; presumably assigned when the server starts (SetUp/Initialize).
+  int port = 0;
+
+ protected:
+  void Initialize();
+
+  void SetUp() override;
+
+  void TearDown() override;
+
+ private:
+  std::shared_ptr<arrow::flight::sql::example::SQLiteFlightSqlServer> server_;
+};
+
+/// \brief Base test fixture for running ODBC V2 tests against a mock server.
+/// Each test file running mock server ODBC V2 tests should define a
+/// fixture inheriting from this base fixture.
+class FlightSQLOdbcV2MockTestBase : public FlightSQLODBCMockTestBase {
+ protected:
+  void SetUp() override;
+};
+
+/// ODBC read buffer size.
+constexpr int kOdbcBufferSize = 1024;
+
+/// \brief Compare two ConnPropertyMaps; keys are compared case-insensitively.
+bool CompareConnPropertyMap(Connection::ConnPropertyMap map1,
+                            Connection::ConnPropertyMap map2);
+
+/// \brief Get the error message reported by the ODBC driver via SQLGetDiagRec.
+std::string GetOdbcErrorMessage(SQLSMALLINT handle_type, SQLHANDLE handle);
+
+// SQLSTATE values checked by the tests.
+constexpr std::string_view kErrorState01004 = "01004";
+constexpr std::string_view kErrorState01S07 = "01S07";
+constexpr std::string_view kErrorState01S02 = "01S02";
+constexpr std::string_view kErrorState07009 = "07009";
+constexpr std::string_view kErrorState08003 = "08003";
+constexpr std::string_view kErrorState22002 = "22002";
+constexpr std::string_view kErrorState24000 = "24000";
+constexpr std::string_view kErrorState28000 = "28000";
+constexpr std::string_view kErrorStateHY000 = "HY000";
+constexpr std::string_view kErrorStateHY004 = "HY004";
+constexpr std::string_view kErrorStateHY009 = "HY009";
+constexpr std::string_view kErrorStateHY010 = "HY010";
+constexpr std::string_view kErrorStateHY017 = "HY017";
+constexpr std::string_view kErrorStateHY024 = "HY024";
+constexpr std::string_view kErrorStateHY090 = "HY090";
+constexpr std::string_view kErrorStateHY091 = "HY091";
+constexpr std::string_view kErrorStateHY092 = "HY092";
+constexpr std::string_view kErrorStateHY106 = "HY106";
+constexpr std::string_view kErrorStateHY114 = "HY114";
+constexpr std::string_view kErrorStateHY118 = "HY118";
+constexpr std::string_view kErrorStateHYC00 = "HYC00";
+constexpr std::string_view kErrorStateS1004 = "S1004";
+
+/// \brief Verify that the ODBC error state for `handle` equals `expected_state`.
+void VerifyOdbcErrorState(SQLSMALLINT handle_type, SQLHANDLE handle,
+                          std::string_view expected_state);
+
+/// \brief Write a connection string into the DSN.
+/// \param[in] connection_str the connection string.
+/// \return true on success
+bool WriteDSN(std::string connection_str);
+
+/// \brief Write a property map into the DSN.
+/// \param[in] properties the connection properties to write.
+/// \return true on success
+bool WriteDSN(Connection::ConnPropertyMap properties);
+
+/// \brief Check a wide-char vector and convert it into a wstring.
+/// \param[in] str_val vector of SQLWCHAR.
+/// \param[in] str_len length of the string, in bytes; when 0, `str_val` is
+///            read as a null-terminated string.
+/// \return the converted wstring
+std::wstring ConvertToWString(const std::vector<SQLWCHAR>& str_val,
+                              SQLSMALLINT str_len);
+
+/// \brief Check a wide string column.
+/// \param[in] stmt Statement.
+/// \param[in] col_id Column ID to check.
+/// \param[in] expected Expected value.
+void CheckStringColumnW(SQLHSTMT stmt, int col_id, const std::wstring& expected);
+
+/// \brief Check that a wide string column value is null.
+/// \param[in] stmt Statement.
+/// \param[in] col_id Column ID to check.
+void CheckNullColumnW(SQLHSTMT stmt, int col_id);
+
+/// \brief Check an int column.
+/// \param[in] stmt Statement.
+/// \param[in] col_id Column ID to check.
+/// \param[in] expected Expected value.
+void CheckIntColumn(SQLHSTMT stmt, int col_id, const SQLINTEGER& expected);
+
+/// \brief Check a smallint column.
+/// \param[in] stmt Statement.
+/// \param[in] col_id Column ID to check.
+/// \param[in] expected Expected value.
+void CheckSmallIntColumn(SQLHSTMT stmt, int col_id, const SQLSMALLINT& expected);
+
+/// \brief Check the return code of SQLFetch against an expected value.
+/// \param[in] stmt Statement.
+/// \param[in] expected Expected return.
+void ValidateFetch(SQLHSTMT stmt, SQLRETURN expected);
+
+} // namespace arrow::flight::sql::odbc