lidavidm commented on code in PR #14082: URL: https://github.com/apache/arrow/pull/14082#discussion_r1041414544
########## cpp/src/arrow/flight/sql/driver.cc: ########## @@ -0,0 +1,1877 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <array> +#include <cmath> +#include <memory> +#include <mutex> +#include <string> +#include <string_view> +#include <unordered_map> + +#include "arrow/array/array_binary.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_base.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/array/builder_union.h" +#include "arrow/c/bridge.h" +#include "arrow/config.h" +#include "arrow/flight/client.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/driver_internal.h" +#include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/types.h" +#include "arrow/flight/sql/visibility.h" +#include "arrow/io/memory.h" +#include "arrow/io/type_fwd.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/ipc/reader.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/config.h" +#include "arrow/util/logging.h" + +#ifdef ARROW_COMPUTE +#include 
"arrow/compute/api_scalar.h" +#include "arrow/compute/exec.h" +#endif + +namespace arrow::flight::sql { + +using arrow::internal::checked_cast; + +namespace { +/// \brief Client-side configuration to help paper over SQL dialect differences +struct FlightSqlQuirks { + /// A mapping from Arrow type to SQL type string + std::unordered_map<Type::type, std::string> ingest_type_mapping; + + FlightSqlQuirks() { + ingest_type_mapping[Type::BINARY] = "BLOB"; + ingest_type_mapping[Type::BOOL] = "BOOLEAN"; + ingest_type_mapping[Type::DATE32] = "DATE"; + ingest_type_mapping[Type::DATE64] = "DATE"; + ingest_type_mapping[Type::DECIMAL128] = "NUMERIC"; + ingest_type_mapping[Type::DECIMAL256] = "NUMERIC"; + ingest_type_mapping[Type::DOUBLE] = "DOUBLE PRECISION"; + ingest_type_mapping[Type::FLOAT] = "REAL"; + ingest_type_mapping[Type::INT16] = "SMALLINT"; + ingest_type_mapping[Type::INT32] = "INT"; + ingest_type_mapping[Type::INT64] = "BIGINT"; + ingest_type_mapping[Type::LARGE_BINARY] = "BLOB"; + ingest_type_mapping[Type::LARGE_STRING] = "TEXT"; + ingest_type_mapping[Type::STRING] = "TEXT"; + ingest_type_mapping[Type::TIME32] = "TIME"; + ingest_type_mapping[Type::TIME64] = "TIME"; + ingest_type_mapping[Type::TIMESTAMP] = "TIMESTAMP"; + } + + bool UpdateTypeMapping(std::string_view type_name, const char* value) { + if (type_name == "binary") { + ingest_type_mapping[Type::BINARY] = value; + } else if (type_name == "bool") { + ingest_type_mapping[Type::BOOL] = value; + } else if (type_name == "date32") { + ingest_type_mapping[Type::DATE32] = value; + } else if (type_name == "date64") { + ingest_type_mapping[Type::DATE64] = value; + } else if (type_name == "decimal128") { + ingest_type_mapping[Type::DECIMAL128] = value; + } else if (type_name == "decimal256") { + ingest_type_mapping[Type::DECIMAL256] = value; + } else if (type_name == "double") { + ingest_type_mapping[Type::DOUBLE] = value; + } else if (type_name == "float") { + ingest_type_mapping[Type::FLOAT] = value; + } else if 
(type_name == "int16") { + ingest_type_mapping[Type::INT16] = value; + } else if (type_name == "int32") { + ingest_type_mapping[Type::INT32] = value; + } else if (type_name == "int64") { + ingest_type_mapping[Type::INT64] = value; + } else if (type_name == "large_binary") { + ingest_type_mapping[Type::LARGE_BINARY] = value; + } else if (type_name == "large_string") { + ingest_type_mapping[Type::LARGE_STRING] = value; + } else if (type_name == "string") { + ingest_type_mapping[Type::STRING] = value; + } else if (type_name == "time32") { + ingest_type_mapping[Type::TIME32] = value; + } else if (type_name == "time64") { + ingest_type_mapping[Type::TIME64] = value; + } else if (type_name == "timestamp") { + ingest_type_mapping[Type::TIMESTAMP] = value; + } else { + return false; + } + return true; + } +}; + +/// Config options used to override the type mapping in FlightSqlQuirks +constexpr std::string_view kIngestTypePrefix = "arrow.flight.sql.quirks.ingest_type."; +/// Explicitly specify the Substrait version for Flight SQL (although +/// Substrait will eventually embed this into the plan itself) +constexpr std::string_view kStatementSubstraitVersionKey = + "arrow.flight.sql.substrait.version"; +/// Attach arbitrary key-value headers via Flight +constexpr std::string_view kCallHeaderPrefix = "arrow.flight.sql.rpc.call_header."; +/// A timeout for any DoGet requests +constexpr std::string_view kConnectionTimeoutFetchKey = + "arrow.flight.sql.rpc.timeout_seconds.fetch"; +/// A timeout for any GetFlightInfo requests +constexpr std::string_view kConnectionTimeoutQueryKey = + "arrow.flight.sql.rpc.timeout_seconds.query"; +/// A timeout for any DoPut requests, or miscellaneous DoAction requests +constexpr std::string_view kConnectionTimeoutUpdateKey = + "arrow.flight.sql.rpc.timeout_seconds.update"; +constexpr std::string_view kConnectionOptionAutocommit = + ADBC_CONNECTION_OPTION_AUTOCOMMIT; +constexpr std::string_view kIngestOptionMode = ADBC_INGEST_OPTION_MODE; 
+constexpr std::string_view kIngestOptionModeAppend = ADBC_INGEST_OPTION_MODE_APPEND; +constexpr std::string_view kIngestOptionModeCreate = ADBC_INGEST_OPTION_MODE_CREATE; +constexpr std::string_view kIngestOptionTargetTable = ADBC_INGEST_OPTION_TARGET_TABLE; +constexpr std::string_view kOptionValueEnabled = ADBC_OPTION_VALUE_ENABLED; +constexpr std::string_view kOptionValueDisabled = ADBC_OPTION_VALUE_DISABLED; + +enum class CallContext { + kFetch, + kQuery, + kUpdate, +}; + +/// \brief AdbcDatabase implementation +class FlightSqlDatabaseImpl { + public: + FlightSqlDatabaseImpl() : client_(nullptr) { + quirks_ = std::make_shared<FlightSqlQuirks>(); + } + + FlightSqlClient* Connect() { + std::lock_guard<std::mutex> guard(mutex_); + if (client_) ++connection_count_; + return client_.get(); + } + + const std::shared_ptr<FlightSqlQuirks>& quirks() const { return quirks_; } + [[nodiscard]] FlightCallOptions MakeCallOptions(CallContext context) const { + FlightCallOptions options; + for (const auto& header : call_headers_) { + options.headers.emplace_back(header.first, header.second); + } + return options; + } + + AdbcStatusCode Init(struct AdbcError* error) { + if (client_) { + SetError(error, "Database already initialized"); + return ADBC_STATUS_INVALID_STATE; + } + auto it = options_.find("uri"); + if (it == options_.end()) { + SetError(error, "Must provide 'uri' option"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + Location location; + ADBC_ARROW_RETURN_NOT_OK(INVALID_ARGUMENT, error, + Location::Parse(it->second).Value(&location)); + + FlightClientOptions client_options = DefaultClientOptions(); + std::unique_ptr<FlightClient> flight_client; + ADBC_ARROW_RETURN_NOT_OK( + IO, error, FlightClient::Connect(location, client_options).Value(&flight_client)); + + client_ = std::make_unique<FlightSqlClient>(std::move(flight_client)); + options_.clear(); + return ADBC_STATUS_OK; + } + + AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) 
{ + if (key == nullptr) { + SetError(error, "Key must not be null"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + std::string_view key_view(key); + std::string_view val_view = value ? value : ""; + if (key_view.rfind(kIngestTypePrefix, 0) == 0) { + const std::string_view type_name = key_view.substr(kIngestTypePrefix.size()); + if (!quirks_->UpdateTypeMapping(type_name, value)) { + SetError(error, "Unknown option value ", key_view, "=", val_view, ": type name ", + type_name, " is not recognized"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + return ADBC_STATUS_OK; + } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) { + std::string header(key_view.substr(kCallHeaderPrefix.size())); + if (value == nullptr) { + call_headers_.erase(header); + } else { + call_headers_.insert({std::move(header), std::string(val_view)}); + } + return ADBC_STATUS_OK; + } + + if (client_) { + SetError(error, "Database already initialized"); + return ADBC_STATUS_INVALID_STATE; + } + options_[std::string(key_view)] = std::string(val_view); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Disconnect(struct AdbcError* error) { + std::lock_guard<std::mutex> guard(mutex_); + if (--connection_count_ < 0) { + SetError(error, "Connection count underflow"); + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; + } + + AdbcStatusCode Release(struct AdbcError* error) { + std::lock_guard<std::mutex> guard(mutex_); + + if (connection_count_ > 0) { + SetError(error, "Cannot release database with ", connection_count_, + " open connections"); + return ADBC_STATUS_INTERNAL; + } + + if (client_) { + auto status = client_->Close(); + client_.reset(); + if (!status.ok()) { + SetError(error, status); + return ADBC_STATUS_IO; + } + } + return ADBC_STATUS_OK; + } + + private: + std::unique_ptr<FlightSqlClient> client_; + std::shared_ptr<FlightSqlQuirks> quirks_; + std::unordered_map<std::string, std::string> call_headers_; + std::unordered_map<std::string, std::string> options_; + std::mutex mutex_; + 
int connection_count_ = 0; +}; + +/// \brief A RecordBatchReader that reads the endpoints of a FlightInfo +class FlightInfoReader : public RecordBatchReader { + public: + explicit FlightInfoReader(FlightSqlClient* client, FlightCallOptions call_options, + std::unique_ptr<FlightInfo> info) + : client_(client), + call_options_(std::move(call_options)), + info_(std::move(info)), + next_endpoint_(0) {} + + [[nodiscard]] std::shared_ptr<Schema> schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { + FlightStreamChunk chunk; + while (current_stream_ && !chunk.data) { + ARROW_ASSIGN_OR_RAISE(chunk, current_stream_->Next()); + if (chunk.data) { + *batch = chunk.data; + break; + } + if (!chunk.data && !chunk.app_metadata) { + RETURN_NOT_OK(NextStream()); + } + } + if (!current_stream_) *batch = nullptr; + return Status::OK(); + } + + Status Close() override { + if (current_stream_) { + current_stream_->Cancel(); + } + return Status::OK(); + } + + AdbcStatusCode Init(struct AdbcError* error) { + ADBC_ARROW_RETURN_NOT_OK(IO, error, NextStream()); + if (!schema_) { + // Empty result set - fall back on schema in FlightInfo + ipc::DictionaryMemo memo; + ADBC_ARROW_RETURN_NOT_OK(INTERNAL, error, info_->GetSchema(&memo).Value(&schema_)); + } + return ADBC_STATUS_OK; + } + + /// \brief Export to an ArrowArrayStream + static AdbcStatusCode Export(FlightSqlClient* client, FlightCallOptions call_options, + std::unique_ptr<FlightInfo> info, + struct ArrowArrayStream* stream, struct AdbcError* error) { + auto reader = std::make_shared<FlightInfoReader>(client, std::move(call_options), + std::move(info)); + ADBC_RETURN_NOT_OK(reader->Init(error)); + ADBC_ARROW_RETURN_NOT_OK(INTERNAL, error, + ExportRecordBatchReader(std::move(reader), stream)); + return ADBC_STATUS_OK; + } + + private: + Status NextStream() { + if (next_endpoint_ >= info_->endpoints().size()) { + current_stream_ = nullptr; + return Status::OK(); + } + const 
FlightEndpoint& endpoint = info_->endpoints()[next_endpoint_]; + + if (endpoint.locations.empty()) { + ARROW_ASSIGN_OR_RAISE(current_stream_, + client_->DoGet(call_options_, endpoint.ticket)); + } else { + // TODO(lidavidm): this should come from a connection pool + std::string failures; + for (const Location& location : endpoint.locations) { + auto status = + FlightClient::Connect(location, DefaultClientOptions()).Value(&data_client_); + if (!status.ok()) { + if (!failures.empty()) { + failures += "; "; + } + failures += location.ToString(); + failures += ": "; + failures += status.ToString(); + data_client_.reset(); + continue; + } + break; + } + + if (!data_client_) { + return Status::IOError("Failed to connect to all endpoints: ", failures); + } + ARROW_ASSIGN_OR_RAISE(current_stream_, + data_client_->DoGet(call_options_, endpoint.ticket)); + } + next_endpoint_++; + if (!schema_) { + ARROW_ASSIGN_OR_RAISE(schema_, current_stream_->GetSchema()); + } + return Status::OK(); + } + + FlightSqlClient* client_; + FlightCallOptions call_options_; + std::unique_ptr<FlightInfo> info_; + size_t next_endpoint_; + std::shared_ptr<Schema> schema_; + std::unique_ptr<FlightStreamReader> current_stream_; + // TODO(lidavidm): use a common pool of cached clients with expiration + std::unique_ptr<FlightClient> data_client_; +}; + +class FlightSqlConnectionImpl { + public: + //---------------------------------------------------------- + // Common Functions + //---------------------------------------------------------- + + [[nodiscard]] FlightSqlClient* client() const { return client_; } + [[nodiscard]] const FlightSqlQuirks& quirks() const { return *quirks_; } + [[nodiscard]] const Transaction& transaction() const { return transaction_; } + [[nodiscard]] FlightCallOptions MakeCallOptions(CallContext context) const { + FlightCallOptions options = database_->MakeCallOptions(context); + auto it = timeout_seconds_.find(context); + if (it != timeout_seconds_.end()) { + options.timeout = 
it->second; + } + for (const auto& header : call_headers_) { + options.headers.emplace_back(header.first, header.second); + } + return options; + } + + AdbcStatusCode Init(struct AdbcDatabase* database, struct AdbcError* error) { + if (!database->private_data) { + SetError(error, "database is not initialized"); + return ADBC_STATUS_INVALID_STATE; + } + + database_ = *reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>( + database->private_data); + client_ = database_->Connect(); + if (!client_) { + SetError(error, "Database not yet initialized!"); + return ADBC_STATUS_INVALID_STATE; + } + quirks_ = database_->quirks(); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Close(struct AdbcError* error) { + if (database_) { + ADBC_RETURN_NOT_OK(database_->Disconnect(error)); + } + return ADBC_STATUS_OK; + } + + AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) { + if (key == nullptr) { + SetError(error, "Key must not be null"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + auto set_timeout_option = [=](CallContext context) -> AdbcStatusCode { + double timeout = 0.0; + size_t pos = 0; + const size_t len = std::strlen(value); + try { + timeout = std::stod(std::string(value, len), &pos); + } catch (const std::exception& e) { + SetError(error, "Invalid option value ", key, '=', value, ": ", e.what()); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (pos != len) { + SetError(error, "Invalid option value ", key, '=', value, + ": trailing characters after numeric literal"); + return ADBC_STATUS_INVALID_ARGUMENT; + } else if (std::isnan(timeout) || std::isinf(timeout) || timeout < 0) { + SetError(error, "Invalid option value ", key, '=', value, + ": timeout must be positive and finite"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + if (timeout == 0) { + timeout_seconds_.erase(context); + } else { + timeout_seconds_[context] = std::chrono::duration<double>(timeout); + } + return ADBC_STATUS_OK; + }; + + std::string_view key_view(key); + 
std::string_view val_view = value ? value : ""; + if (key == kConnectionOptionAutocommit) { + // TODO(lidavidm): should query server metadata to see if this is possible + FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate); + if (val_view == kOptionValueEnabled) { + if (transaction_.is_valid()) { + ADBC_ARROW_RETURN_NOT_OK(IO, error, + client_->Commit(call_options, transaction_)); + transaction_ = no_transaction(); + } + return ADBC_STATUS_OK; + } else if (val_view == kOptionValueDisabled) { + if (transaction_.is_valid()) { + ADBC_ARROW_RETURN_NOT_OK(IO, error, + client_->Commit(call_options, transaction_)); + } + ADBC_ARROW_RETURN_NOT_OK( + IO, error, client_->BeginTransaction(call_options).Value(&transaction_)); + return ADBC_STATUS_OK; + } + SetError(error, "Invalid option value ", key_view, '=', val_view); + return ADBC_STATUS_INVALID_ARGUMENT; + } else if (key == kConnectionTimeoutFetchKey) { + return set_timeout_option(CallContext::kFetch); + } else if (key == kConnectionTimeoutQueryKey) { + return set_timeout_option(CallContext::kQuery); + } else if (key == kConnectionTimeoutUpdateKey) { + return set_timeout_option(CallContext::kUpdate); + } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) { + std::string header(key_view.substr(kCallHeaderPrefix.size())); + if (value == nullptr) { + call_headers_.erase(header); + } else { + call_headers_.insert({std::move(header), value}); + } + return ADBC_STATUS_OK; + } + SetError(error, "Unknown connection option ", key_view, '=', val_view); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + //---------------------------------------------------------- + // Metadata + //---------------------------------------------------------- + + AdbcStatusCode GetInfo(uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* stream, struct AdbcError* error) { + static std::shared_ptr<arrow::Schema> kInfoSchema = arrow::schema({ + arrow::field("info_name", arrow::uint32(), /*nullable=*/false), + 
arrow::field( + "info_value", + arrow::dense_union({ + arrow::field("string_value", arrow::utf8()), + arrow::field("bool_value", arrow::boolean()), + arrow::field("int64_value", arrow::int64()), + arrow::field("int32_bitmask", arrow::int32()), + arrow::field("string_list", arrow::list(arrow::utf8())), + arrow::field("int32_to_int32_list_map", + arrow::map(arrow::int32(), arrow::list(arrow::int32()))), + })), + }); + + // XXX(ARROW-17558): type should be uint32_t not int + std::vector<int> flight_sql_codes; + std::vector<uint32_t> codes; + if (info_codes && info_codes_length > 0) { + for (size_t i = 0; i < info_codes_length; i++) { + const uint32_t info_code = info_codes[i]; + switch (info_code) { + case ADBC_INFO_VENDOR_NAME: + case ADBC_INFO_VENDOR_VERSION: + case ADBC_INFO_VENDOR_ARROW_VERSION: + // These codes are equivalent between the two + flight_sql_codes.push_back(info_code); + break; + case ADBC_INFO_DRIVER_NAME: + case ADBC_INFO_DRIVER_VERSION: + case ADBC_INFO_DRIVER_ARROW_VERSION: + codes.push_back(info_code); + break; + default: + SetError(error, "Unknown info code: ", info_code); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + } else { + flight_sql_codes = { + SqlInfoOptions::FLIGHT_SQL_SERVER_NAME, + SqlInfoOptions::FLIGHT_SQL_SERVER_VERSION, + SqlInfoOptions::FLIGHT_SQL_SERVER_ARROW_VERSION, + }; + codes = { + ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ARROW_VERSION, + }; + } + + RecordBatchVector result; + + UInt32Builder names; + std::unique_ptr<ArrayBuilder> values; + ADBC_ARROW_RETURN_NOT_OK(INTERNAL, error, + MakeBuilder(kInfoSchema->field(1)->type()).Value(&values)); + auto* info_value = static_cast<DenseUnionBuilder*>(values.get()); + auto* info_string = static_cast<StringBuilder*>(info_value->child_builder(0).get()); + int64_t num_values = 0; + + constexpr int8_t kStringCode = 0; + + if (!flight_sql_codes.empty()) { + FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery); + 
std::unique_ptr<FlightInfo> info; + ADBC_ARROW_RETURN_NOT_OK( + IO, error, client_->GetSqlInfo(call_options, flight_sql_codes).Value(&info)); + FlightInfoReader reader(client_, MakeCallOptions(CallContext::kFetch), + std::move(info)); + ADBC_RETURN_NOT_OK(reader.Init(error)); + + if (!reader.schema()->Equals(*SqlSchema::GetSqlInfoSchema())) { + SetError(error, "Server returned wrong schema, got: ", *reader.schema()); + return ADBC_STATUS_INTERNAL; + } + + while (true) { + std::shared_ptr<RecordBatch> batch; + ADBC_ARROW_RETURN_NOT_OK(IO, error, reader.Next().Value(&batch)); + if (!batch) break; + + const auto& sql_codes = static_cast<const UInt32Array&>(*batch->column(0)); Review Comment: Actually, we already validate above that the record batch reader's schema matches the expected schema -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
