lidavidm commented on code in PR #14082:
URL: https://github.com/apache/arrow/pull/14082#discussion_r1041425527


##########
cpp/src/arrow/flight/sql/driver.cc:
##########
@@ -0,0 +1,1877 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <array>
+#include <cmath>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_union.h"
+#include "arrow/c/bridge.h"
+#include "arrow/config.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/sql/client.h"
+#include "arrow/flight/sql/driver_internal.h"
+#include "arrow/flight/sql/server.h"
+#include "arrow/flight/sql/types.h"
+#include "arrow/flight/sql/visibility.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/type_fwd.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/config.h"
+#include "arrow/util/logging.h"
+
+#ifdef ARROW_COMPUTE
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/exec.h"
+#endif
+
+namespace arrow::flight::sql {
+
+using arrow::internal::checked_cast;
+
+namespace {
+/// \brief Client-side configuration to help paper over SQL dialect differences
+struct FlightSqlQuirks {
+  /// A mapping from Arrow type to SQL type string
+  std::unordered_map<Type::type, std::string> ingest_type_mapping;
+
+  FlightSqlQuirks() {
+    ingest_type_mapping[Type::BINARY] = "BLOB";
+    ingest_type_mapping[Type::BOOL] = "BOOLEAN";
+    ingest_type_mapping[Type::DATE32] = "DATE";
+    ingest_type_mapping[Type::DATE64] = "DATE";
+    ingest_type_mapping[Type::DECIMAL128] = "NUMERIC";
+    ingest_type_mapping[Type::DECIMAL256] = "NUMERIC";
+    ingest_type_mapping[Type::DOUBLE] = "DOUBLE PRECISION";
+    ingest_type_mapping[Type::FLOAT] = "REAL";
+    ingest_type_mapping[Type::INT16] = "SMALLINT";
+    ingest_type_mapping[Type::INT32] = "INT";
+    ingest_type_mapping[Type::INT64] = "BIGINT";
+    ingest_type_mapping[Type::LARGE_BINARY] = "BLOB";
+    ingest_type_mapping[Type::LARGE_STRING] = "TEXT";
+    ingest_type_mapping[Type::STRING] = "TEXT";
+    ingest_type_mapping[Type::TIME32] = "TIME";
+    ingest_type_mapping[Type::TIME64] = "TIME";
+    ingest_type_mapping[Type::TIMESTAMP] = "TIMESTAMP";
+  }
+
+  bool UpdateTypeMapping(std::string_view type_name, const char* value) {
+    if (type_name == "binary") {
+      ingest_type_mapping[Type::BINARY] = value;
+    } else if (type_name == "bool") {
+      ingest_type_mapping[Type::BOOL] = value;
+    } else if (type_name == "date32") {
+      ingest_type_mapping[Type::DATE32] = value;
+    } else if (type_name == "date64") {
+      ingest_type_mapping[Type::DATE64] = value;
+    } else if (type_name == "decimal128") {
+      ingest_type_mapping[Type::DECIMAL128] = value;
+    } else if (type_name == "decimal256") {
+      ingest_type_mapping[Type::DECIMAL256] = value;
+    } else if (type_name == "double") {
+      ingest_type_mapping[Type::DOUBLE] = value;
+    } else if (type_name == "float") {
+      ingest_type_mapping[Type::FLOAT] = value;
+    } else if (type_name == "int16") {
+      ingest_type_mapping[Type::INT16] = value;
+    } else if (type_name == "int32") {
+      ingest_type_mapping[Type::INT32] = value;
+    } else if (type_name == "int64") {
+      ingest_type_mapping[Type::INT64] = value;
+    } else if (type_name == "large_binary") {
+      ingest_type_mapping[Type::LARGE_BINARY] = value;
+    } else if (type_name == "large_string") {
+      ingest_type_mapping[Type::LARGE_STRING] = value;
+    } else if (type_name == "string") {
+      ingest_type_mapping[Type::STRING] = value;
+    } else if (type_name == "time32") {
+      ingest_type_mapping[Type::TIME32] = value;
+    } else if (type_name == "time64") {
+      ingest_type_mapping[Type::TIME64] = value;
+    } else if (type_name == "timestamp") {
+      ingest_type_mapping[Type::TIMESTAMP] = value;
+    } else {
+      return false;
+    }
+    return true;
+  }
+};
+
+/// Config options used to override the type mapping in FlightSqlQuirks
+constexpr std::string_view kIngestTypePrefix = 
"arrow.flight.sql.quirks.ingest_type.";
+/// Explicitly specify the Substrait version for Flight SQL (although
+/// Substrait will eventually embed this into the plan itself)
+constexpr std::string_view kStatementSubstraitVersionKey =
+    "arrow.flight.sql.substrait.version";
+/// Attach arbitrary key-value headers via Flight
+constexpr std::string_view kCallHeaderPrefix = 
"arrow.flight.sql.rpc.call_header.";
+/// A timeout for any DoGet requests
+constexpr std::string_view kConnectionTimeoutFetchKey =
+    "arrow.flight.sql.rpc.timeout_seconds.fetch";
+/// A timeout for any GetFlightInfo requests
+constexpr std::string_view kConnectionTimeoutQueryKey =
+    "arrow.flight.sql.rpc.timeout_seconds.query";
+/// A timeout for any DoPut requests, or miscellaneous DoAction requests
+constexpr std::string_view kConnectionTimeoutUpdateKey =
+    "arrow.flight.sql.rpc.timeout_seconds.update";
+constexpr std::string_view kConnectionOptionAutocommit =
+    ADBC_CONNECTION_OPTION_AUTOCOMMIT;
+constexpr std::string_view kIngestOptionMode = ADBC_INGEST_OPTION_MODE;
+constexpr std::string_view kIngestOptionModeAppend = 
ADBC_INGEST_OPTION_MODE_APPEND;
+constexpr std::string_view kIngestOptionModeCreate = 
ADBC_INGEST_OPTION_MODE_CREATE;
+constexpr std::string_view kIngestOptionTargetTable = 
ADBC_INGEST_OPTION_TARGET_TABLE;
+constexpr std::string_view kOptionValueEnabled = ADBC_OPTION_VALUE_ENABLED;
+constexpr std::string_view kOptionValueDisabled = ADBC_OPTION_VALUE_DISABLED;
+
+enum class CallContext {
+  kFetch,
+  kQuery,
+  kUpdate,
+};
+
+/// \brief AdbcDatabase implementation
+class FlightSqlDatabaseImpl {
+ public:
+  FlightSqlDatabaseImpl() : client_(nullptr) {
+    quirks_ = std::make_shared<FlightSqlQuirks>();
+  }
+
+  FlightSqlClient* Connect() {
+    std::lock_guard<std::mutex> guard(mutex_);
+    if (client_) ++connection_count_;
+    return client_.get();
+  }
+
+  const std::shared_ptr<FlightSqlQuirks>& quirks() const { return quirks_; }
+  [[nodiscard]] FlightCallOptions MakeCallOptions(CallContext context) const {
+    FlightCallOptions options;
+    for (const auto& header : call_headers_) {
+      options.headers.emplace_back(header.first, header.second);
+    }
+    return options;
+  }
+
+  AdbcStatusCode Init(struct AdbcError* error) {
+    if (client_) {
+      SetError(error, "Database already initialized");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+    auto it = options_.find("uri");
+    if (it == options_.end()) {
+      SetError(error, "Must provide 'uri' option");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    Location location;
+    ADBC_ARROW_RETURN_NOT_OK(INVALID_ARGUMENT, error,
+                             Location::Parse(it->second).Value(&location));
+
+    FlightClientOptions client_options = DefaultClientOptions();
+    std::unique_ptr<FlightClient> flight_client;
+    ADBC_ARROW_RETURN_NOT_OK(
+        IO, error, FlightClient::Connect(location, 
client_options).Value(&flight_client));
+
+    client_ = std::make_unique<FlightSqlClient>(std::move(flight_client));
+    options_.clear();
+    return ADBC_STATUS_OK;
+  }
+
+  AdbcStatusCode SetOption(const char* key, const char* value, struct 
AdbcError* error) {
+    if (key == nullptr) {
+      SetError(error, "Key must not be null");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    std::string_view key_view(key);
+    std::string_view val_view = value ? value : "";
+    if (key_view.rfind(kIngestTypePrefix, 0) == 0) {
+      const std::string_view type_name = 
key_view.substr(kIngestTypePrefix.size());
+      if (!quirks_->UpdateTypeMapping(type_name, value)) {
+        SetError(error, "Unknown option value ", key_view, "=", val_view, ": 
type name ",
+                 type_name, " is not recognized");
+        return ADBC_STATUS_INVALID_ARGUMENT;
+      }
+      return ADBC_STATUS_OK;
+    } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) {
+      std::string header(key_view.substr(kCallHeaderPrefix.size()));
+      if (value == nullptr) {
+        call_headers_.erase(header);
+      } else {
+        call_headers_.insert({std::move(header), std::string(val_view)});
+      }
+      return ADBC_STATUS_OK;
+    }
+
+    if (client_) {
+      SetError(error, "Database already initialized");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+    options_[std::string(key_view)] = std::string(val_view);
+    return ADBC_STATUS_OK;
+  }
+
+  AdbcStatusCode Disconnect(struct AdbcError* error) {
+    std::lock_guard<std::mutex> guard(mutex_);
+    if (--connection_count_ < 0) {
+      SetError(error, "Connection count underflow");
+      return ADBC_STATUS_INTERNAL;
+    }
+    return ADBC_STATUS_OK;
+  }
+
+  AdbcStatusCode Release(struct AdbcError* error) {
+    std::lock_guard<std::mutex> guard(mutex_);
+
+    if (connection_count_ > 0) {
+      SetError(error, "Cannot release database with ", connection_count_,
+               " open connections");
+      return ADBC_STATUS_INTERNAL;
+    }
+
+    if (client_) {
+      auto status = client_->Close();
+      client_.reset();
+      if (!status.ok()) {
+        SetError(error, status);
+        return ADBC_STATUS_IO;
+      }
+    }
+    return ADBC_STATUS_OK;
+  }
+
+ private:
+  std::unique_ptr<FlightSqlClient> client_;
+  std::shared_ptr<FlightSqlQuirks> quirks_;
+  std::unordered_map<std::string, std::string> call_headers_;
+  std::unordered_map<std::string, std::string> options_;
+  std::mutex mutex_;
+  int connection_count_ = 0;
+};
+
+/// \brief A RecordBatchReader that reads the endpoints of a FlightInfo
+class FlightInfoReader : public RecordBatchReader {
+ public:
+  explicit FlightInfoReader(FlightSqlClient* client, FlightCallOptions 
call_options,
+                            std::unique_ptr<FlightInfo> info)
+      : client_(client),
+        call_options_(std::move(call_options)),
+        info_(std::move(info)),
+        next_endpoint_(0) {}
+
+  [[nodiscard]] std::shared_ptr<Schema> schema() const override { return 
schema_; }
+
+  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+    FlightStreamChunk chunk;
+    while (current_stream_ && !chunk.data) {
+      ARROW_ASSIGN_OR_RAISE(chunk, current_stream_->Next());
+      if (chunk.data) {
+        *batch = chunk.data;
+        break;
+      }
+      if (!chunk.data && !chunk.app_metadata) {
+        RETURN_NOT_OK(NextStream());
+      }
+    }
+    if (!current_stream_) *batch = nullptr;
+    return Status::OK();
+  }
+
+  Status Close() override {
+    if (current_stream_) {
+      current_stream_->Cancel();
+    }
+    return Status::OK();
+  }
+
+  AdbcStatusCode Init(struct AdbcError* error) {
+    ADBC_ARROW_RETURN_NOT_OK(IO, error, NextStream());
+    if (!schema_) {
+      // Empty result set - fall back on schema in FlightInfo
+      ipc::DictionaryMemo memo;
+      ADBC_ARROW_RETURN_NOT_OK(INTERNAL, error, 
info_->GetSchema(&memo).Value(&schema_));
+    }
+    return ADBC_STATUS_OK;
+  }
+
+  /// \brief Export to an ArrowArrayStream
+  static AdbcStatusCode Export(FlightSqlClient* client, FlightCallOptions 
call_options,
+                               std::unique_ptr<FlightInfo> info,
+                               struct ArrowArrayStream* stream, struct 
AdbcError* error) {
+    auto reader = std::make_shared<FlightInfoReader>(client, 
std::move(call_options),
+                                                     std::move(info));
+    ADBC_RETURN_NOT_OK(reader->Init(error));
+    ADBC_ARROW_RETURN_NOT_OK(INTERNAL, error,
+                             ExportRecordBatchReader(std::move(reader), 
stream));
+    return ADBC_STATUS_OK;
+  }
+
+ private:
+  Status NextStream() {
+    if (next_endpoint_ >= info_->endpoints().size()) {
+      current_stream_ = nullptr;
+      return Status::OK();
+    }
+    const FlightEndpoint& endpoint = info_->endpoints()[next_endpoint_];
+
+    if (endpoint.locations.empty()) {
+      ARROW_ASSIGN_OR_RAISE(current_stream_,
+                            client_->DoGet(call_options_, endpoint.ticket));
+    } else {
+      // TODO(lidavidm): this should come from a connection pool
+      std::string failures;
+      for (const Location& location : endpoint.locations) {
+        auto status =
+            FlightClient::Connect(location, 
DefaultClientOptions()).Value(&data_client_);
+        if (!status.ok()) {
+          if (!failures.empty()) {
+            failures += "; ";
+          }
+          failures += location.ToString();
+          failures += ": ";
+          failures += status.ToString();
+          data_client_.reset();
+          continue;
+        }
+        break;
+      }
+
+      if (!data_client_) {
+        return Status::IOError("Failed to connect to all endpoints: ", 
failures);
+      }
+      ARROW_ASSIGN_OR_RAISE(current_stream_,
+                            data_client_->DoGet(call_options_, 
endpoint.ticket));

Review Comment:
   We can't reset it here since we still need to use the stream (the stream 
presumably holds a strong reference to the client internally).
   
   I'll add the fallback for DoGet as well, good point.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to