justing-bq commented on code in PR #47788: URL: https://github.com/apache/arrow/pull/47788#discussion_r2449186109
########## cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h: ########## @@ -0,0 +1,234 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/gtest_util.h" +#include "arrow/util/io_util.h" +#include "arrow/util/utf8.h" + +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/flight/sql/odbc/odbc_impl/encoding_utils.h" +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include <sql.h> +#include <sqltypes.h> +#include <sqlucode.h> + +#include <type_traits> + +#include "arrow/flight/sql/odbc/odbc_impl/odbc_connection.h" + +// For DSN registration +#include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h" + +#define TEST_CONNECT_STR "ARROW_FLIGHT_SQL_ODBC_CONN" +#define TEST_DSN "Apache Arrow Flight SQL Test DSN" + +namespace arrow::flight::sql::odbc { + +class FlightSQLODBCRemoteTestBase : public ::testing::Test { + public: + /// \brief Allocate environment and connection handles + void AllocEnvConnHandles(SQLINTEGER odbc_ver = SQL_OV_ODBC3); + /// \brief Connect to Arrow Flight SQL server using connection string defined in + /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN", allocate statement handle. + /// Connects using ODBC Ver 3 by default + void Connect(SQLINTEGER odbc_ver = SQL_OV_ODBC3); + /// \brief Connect to Arrow Flight SQL server using connection string + void ConnectWithString(std::string connection_str); + /// \brief Disconnect from server + void Disconnect(); + /// \brief Get connection string from environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" + std::string virtual GetConnectionString(); + /// \brief Get invalid connection string based on connection string defined in + /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" + std::string virtual GetInvalidConnectionString(); + /// \brief Return a SQL query that selects all data types + std::wstring virtual GetQueryAllDataTypes(); + + /** ODBC Environment. */ + SQLHENV env = 0; + + /** ODBC Connect. */ + SQLHDBC conn = 0; + + /** ODBC Statement. */ + SQLHSTMT stmt = 0; + + protected: + void SetUp() override; +}; + +static constexpr std::string_view kAuthorizationHeader = "authorization"; +static constexpr std::string_view kBearerPrefix = "Bearer "; +static constexpr std::string_view kTestToken = "t0k3n"; + +std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers); + +// A server middleware for validating incoming bearer header authentication. +class MockServerMiddleware : public ServerMiddleware { + public: + explicit MockServerMiddleware(const CallHeaders& incoming_headers, bool* is_valid) + : is_valid_(is_valid) { + incoming_headers_ = incoming_headers; + } + + void SendingHeaders(AddCallHeaders* outgoing_headers) override; + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "MockServerMiddleware"; } + + private: + CallHeaders incoming_headers_; + bool* is_valid_; +}; + +// Factory for base64 header authentication testing. +class MockServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + MockServerMiddlewareFactory() : is_valid_(false) {} + + Status StartCall(const CallInfo& info, const ServerCallContext& context, + std::shared_ptr<ServerMiddleware>* middleware) override; + + private: + bool is_valid_; +}; + +class FlightSQLODBCMockTestBase : public FlightSQLODBCRemoteTestBase { + // Sets up a mock server for each test case + public: + /// \brief Get connection string for mock server + std::string GetConnectionString() override; + /// \brief Get invalid connection string for mock server + std::string GetInvalidConnectionString() override; + /// \brief Return a SQL query that selects all data types + std::wstring GetQueryAllDataTypes() override; + + /// \brief Run a SQL query to create default table for table test cases + void CreateTestTables(); + + /// \brief run a SQL query to create a table with all data types + void CreateTableAllDataType(); + /// \brief run a SQL query to create a table with unicode name + void CreateUnicodeTable(); + + int port; + + protected: + void SetUp() override; + + void TearDown() override; + + private: + std::shared_ptr<arrow::flight::sql::example::SQLiteFlightSqlServer> server_; +}; + +template <typename T> +class FlightSQLODBCTestBase : public T { + public: + using List = std::list<T>; +}; + +using TestTypes = + ::testing::Types<FlightSQLODBCMockTestBase, FlightSQLODBCRemoteTestBase>; +TYPED_TEST_SUITE(FlightSQLODBCTestBase, TestTypes); + +/** ODBC read buffer size. */ +enum { ODBC_BUFFER_SIZE = 1024 }; Review Comment: Now using a constexpr int. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
