lidavidm commented on code in PR #14266:
URL: https://github.com/apache/arrow/pull/14266#discussion_r1006890482


##########
cpp/src/arrow/flight/sql/client.cc:
##########
@@ -574,26 +584,21 @@ arrow::Result<int64_t> PreparedStatement::ExecuteUpdate(
   command.set_prepared_statement_handle(handle_);
   ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
                         GetFlightDescriptorForCommand(command));
-  std::unique_ptr<FlightStreamWriter> writer;
-  std::unique_ptr<FlightMetadataReader> reader;
-
-  if (parameter_binding_ && parameter_binding_->num_rows() > 0) {
-    ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, 
parameter_binding_->schema(),
-                                       &writer, &reader));
-    ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_));
+  std::shared_ptr<Buffer> metadata;
+  if (parameter_binding_) {
+    ARROW_ASSIGN_OR_RAISE(metadata, BindParameters(client_->impl_.get(), 
options,
+                                                   descriptor, 
parameter_binding_.get()));
   } else {
     const std::shared_ptr<Schema> schema = arrow::schema({});
-    ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, schema, &writer, 
&reader));
-    const ArrayVector columns;
-    const auto& record_batch = arrow::RecordBatch::Make(schema, 0, columns);
-    ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch));
+    auto record_batch = arrow::RecordBatch::Make(schema, 0, ArrayVector{});
+    ARROW_ASSIGN_OR_RAISE(auto params,
+                          RecordBatchReader::Make({std::move(record_batch)}));

Review Comment:
   It's not written in the spec. It also doesn't seem necessary here. I've 
removed it (or rather, replaced it with a 0-batch reader).



##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -77,50 +84,25 @@ std::string PrepareQueryForGetTables(const GetTables& 
command) {
   return table_query.str();
 }
 
-Status SetParametersOnSQLiteStatement(sqlite3_stmt* stmt, FlightMessageReader* 
reader) {
+template <typename Callback>
+Status SetParametersOnSQLiteStatement(SqliteStatement* statement,
+                                      FlightMessageReader* reader, Callback 
callback) {
+  sqlite3_stmt* stmt = statement->GetSqlite3Stmt();
   while (true) {
     ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next());
-    std::shared_ptr<RecordBatch>& record_batch = chunk.data;
-    if (record_batch == nullptr) break;
+    if (chunk.data == nullptr) break;
 
-    const int64_t num_rows = record_batch->num_rows();
-    const int& num_columns = record_batch->num_columns();
+    const int64_t num_rows = chunk.data->num_rows();
+    if (num_rows == 0) continue;
 
+    ARROW_RETURN_NOT_OK(statement->SetParameters({std::move(chunk.data)}));
     for (int i = 0; i < num_rows; ++i) {
-      for (int c = 0; c < num_columns; ++c) {
-        const std::shared_ptr<Array>& column = record_batch->column(c);
-        ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
column->GetScalar(i));
-
-        auto& holder = static_cast<DenseUnionScalar&>(*scalar).value;
-
-        switch (holder->type->id()) {
-          case Type::INT64: {
-            int64_t value = static_cast<Int64Scalar&>(*holder).value;
-            sqlite3_bind_int64(stmt, c + 1, value);
-            break;
-          }
-          case Type::FLOAT: {
-            double value = static_cast<FloatScalar&>(*holder).value;
-            sqlite3_bind_double(stmt, c + 1, value);
-            break;
-          }
-          case Type::STRING: {
-            std::shared_ptr<Buffer> buffer = 
static_cast<StringScalar&>(*holder).value;
-            sqlite3_bind_text(stmt, c + 1, reinterpret_cast<const 
char*>(buffer->data()),
-                              static_cast<int>(buffer->size()), 
SQLITE_TRANSIENT);
-            break;
-          }
-          case Type::BINARY: {
-            std::shared_ptr<Buffer> buffer = 
static_cast<BinaryScalar&>(*holder).value;
-            sqlite3_bind_blob(stmt, c + 1, buffer->data(),
-                              static_cast<int>(buffer->size()), 
SQLITE_TRANSIENT);
-            break;
-          }
-          default:
-            return Status::Invalid("Received unsupported data type: ",
-                                   holder->type->ToString());
-        }
+      if (sqlite3_clear_bindings(stmt) != SQLITE_OK) {
+        return Status::Invalid("Failed to reset bindings on row ", i, ": ",
+                               sqlite3_errmsg(statement->db()));
       }
+      ARROW_RETURN_NOT_OK(statement->Bind(/*batch_index=*/0, i));
+      ARROW_RETURN_NOT_OK(callback());

Review Comment:
   We can't bind multiple rows at once, so we have to execute the query after 
binding each row.



##########
cpp/src/arrow/flight/sql/example/sqlite_statement.h:
##########
@@ -62,15 +62,25 @@ class SqliteStatement {
 
   /// \brief Returns the underlying sqlite3_stmt.
   /// \return A sqlite statement.
-  sqlite3_stmt* GetSqlite3Stmt() const;
+  [[nodiscard]] sqlite3_stmt* GetSqlite3Stmt() const;

Review Comment:
   I've removed them all.



##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -77,50 +84,25 @@ std::string PrepareQueryForGetTables(const GetTables& 
command) {
   return table_query.str();
 }
 
-Status SetParametersOnSQLiteStatement(sqlite3_stmt* stmt, FlightMessageReader* 
reader) {
+template <typename Callback>
+Status SetParametersOnSQLiteStatement(SqliteStatement* statement,
+                                      FlightMessageReader* reader, Callback 
callback) {
+  sqlite3_stmt* stmt = statement->GetSqlite3Stmt();
   while (true) {
     ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next());
-    std::shared_ptr<RecordBatch>& record_batch = chunk.data;
-    if (record_batch == nullptr) break;
+    if (chunk.data == nullptr) break;
 
-    const int64_t num_rows = record_batch->num_rows();
-    const int& num_columns = record_batch->num_columns();
+    const int64_t num_rows = chunk.data->num_rows();
+    if (num_rows == 0) continue;
 
+    ARROW_RETURN_NOT_OK(statement->SetParameters({std::move(chunk.data)}));
     for (int i = 0; i < num_rows; ++i) {
-      for (int c = 0; c < num_columns; ++c) {
-        const std::shared_ptr<Array>& column = record_batch->column(c);
-        ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
column->GetScalar(i));
-
-        auto& holder = static_cast<DenseUnionScalar&>(*scalar).value;
-
-        switch (holder->type->id()) {
-          case Type::INT64: {
-            int64_t value = static_cast<Int64Scalar&>(*holder).value;
-            sqlite3_bind_int64(stmt, c + 1, value);
-            break;
-          }
-          case Type::FLOAT: {
-            double value = static_cast<FloatScalar&>(*holder).value;
-            sqlite3_bind_double(stmt, c + 1, value);
-            break;
-          }
-          case Type::STRING: {
-            std::shared_ptr<Buffer> buffer = 
static_cast<StringScalar&>(*holder).value;
-            sqlite3_bind_text(stmt, c + 1, reinterpret_cast<const 
char*>(buffer->data()),
-                              static_cast<int>(buffer->size()), 
SQLITE_TRANSIENT);
-            break;
-          }
-          case Type::BINARY: {
-            std::shared_ptr<Buffer> buffer = 
static_cast<BinaryScalar&>(*holder).value;
-            sqlite3_bind_blob(stmt, c + 1, buffer->data(),
-                              static_cast<int>(buffer->size()), 
SQLITE_TRANSIENT);
-            break;
-          }
-          default:
-            return Status::Invalid("Received unsupported data type: ",
-                                   holder->type->ToString());
-        }
+      if (sqlite3_clear_bindings(stmt) != SQLITE_OK) {
+        return Status::Invalid("Failed to reset bindings on row ", i, ": ",
+                               sqlite3_errmsg(statement->db()));
       }
+      ARROW_RETURN_NOT_OK(statement->Bind(/*batch_index=*/0, i));

Review Comment:
   Yes, since there's only a single batch being fed through at a time.



##########
cpp/src/arrow/flight/sql/server_test.cc:
##########
@@ -502,51 +489,53 @@ TEST_F(TestFlightSqlServer, 
TestCommandPreparedStatementQueryWithParameterBindin
       auto prepared_statement,
       sql_client->Prepare({}, "SELECT * FROM intTable WHERE keyName LIKE ?"));
 
-  auto parameter_schema = prepared_statement->parameter_schema();
-
+  const std::shared_ptr<Schema>& parameter_schema =
+      prepared_statement->parameter_schema();
   const std::shared_ptr<Schema>& expected_parameter_schema =
       arrow::schema({arrow::field("parameter_1", 
example::GetUnknownColumnDataType())});
+  ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(expected_parameter_schema, 
parameter_schema));
 
-  AssertSchemaEqual(expected_parameter_schema, parameter_schema);
-
-  std::shared_ptr<Array> type_ids = ArrayFromJSON(int8(), R"([0])");
-  std::shared_ptr<Array> offsets = ArrayFromJSON(int32(), R"([0])");
-  std::shared_ptr<Array> string_array = ArrayFromJSON(utf8(), R"(["%one"])");
-  std::shared_ptr<Array> bytes_array = ArrayFromJSON(binary(), R"([])");
-  std::shared_ptr<Array> bigint_array = ArrayFromJSON(int64(), R"([])");
-  std::shared_ptr<Array> double_array = ArrayFromJSON(float64(), R"([])");
-
-  ASSERT_OK_AND_ASSIGN(
-      auto parameter_1_array,
-      DenseUnionArray::Make(*type_ids, *offsets,
-                            {string_array, bytes_array, bigint_array, 
double_array},
-                            {"string", "bytes", "bigint", "double"}, {0, 1, 2, 
3}));
-
-  const std::shared_ptr<RecordBatch>& record_batch =
-      RecordBatch::Make(parameter_schema, 1, {parameter_1_array});
-
-  ASSERT_OK(prepared_statement->SetParameters(record_batch));
+  auto record_batch = RecordBatchFromJSON(parameter_schema, R"([ [[0, "%one"]] 
])");
+  ASSERT_OK(prepared_statement->SetParameters(std::move(record_batch)));
 
   ASSERT_OK_AND_ASSIGN(auto flight_info, prepared_statement->Execute());
-
   ASSERT_OK_AND_ASSIGN(auto stream,
                        sql_client->DoGet({}, 
flight_info->endpoints()[0].ticket));
-
   ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable());
 
   const std::shared_ptr<Schema>& expected_schema =
       arrow::schema({arrow::field("id", int64()), arrow::field("keyName", 
utf8()),
                      arrow::field("value", int64()), arrow::field("foreignId", 
int64())});
 
-  const auto id_array = ArrayFromJSON(int64(), R"([1, 3])");
-  const auto keyname_array = ArrayFromJSON(utf8(), R"(["one", "negative 
one"])");
-  const auto value_array = ArrayFromJSON(int64(), R"([1, -1])");
-  const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1])");
-
-  const std::shared_ptr<Table>& expected_table = Table::Make(
-      expected_schema, {id_array, keyname_array, value_array, 
foreignId_array});
-
-  AssertTablesEqual(*expected_table, *table);
+  auto expected_table = TableFromJSON(expected_schema, {R"([
+      [1, "one", 1, 1],
+      [3, "negative one", -1, 1]
+  ])"});
+  ASSERT_NO_FATAL_FAILURE(AssertTablesEqual(*expected_table, *table, 
/*verbose=*/true));
+
+  // Set multiple parameters at once
+  record_batch =
+      RecordBatchFromJSON(parameter_schema, R"([ [[0, "%one"]], [[0, "%zero"]] 
])");
+  ASSERT_OK(prepared_statement->SetParameters(std::move(record_batch)));
+  ASSERT_OK_AND_ASSIGN(flight_info, prepared_statement->Execute());
+  ASSERT_OK_AND_ASSIGN(stream, sql_client->DoGet({}, 
flight_info->endpoints()[0].ticket));
+  ASSERT_OK_AND_ASSIGN(table, stream->ToTable());
+  expected_table = TableFromJSON(expected_schema, {R"([
+      [1, "one", 1, 1],
+      [3, "negative one", -1, 1],
+      [2, "zero", 0, 1]

Review Comment:
   Adjusted some tests to return nulls.



##########
cpp/src/arrow/flight/sql/example/sqlite_statement.h:
##########
@@ -62,15 +62,25 @@ class SqliteStatement {
 
   /// \brief Returns the underlying sqlite3_stmt.
   /// \return A sqlite statement.
-  sqlite3_stmt* GetSqlite3Stmt() const;
+  [[nodiscard]] sqlite3_stmt* GetSqlite3Stmt() const;

Review Comment:
   The nodiscard thing is a quirk of the code linter in clangd; it suggests 
`[[nodiscard]]` on basically every getter due to 
https://clang.llvm.org/extra/clang-tidy/checks/modernize/use-nodiscard.html.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to