This is an automated email from the ASF dual-hosted git repository. lidavidm pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push: new 0e03c922 fix(c/driver/sqlite): Fix parameter binding when inferring types and when retrieving (#742) 0e03c922 is described below commit 0e03c922afa3284425d1fce726e48aa971209888 Author: Kirill Müller <krl...@users.noreply.github.com> AuthorDate: Mon Jun 12 23:22:32 2023 +0200 fix(c/driver/sqlite): Fix parameter binding when inferring types and when retrieving (#742) Needs tests on the C side and perhaps also on the R side. Please advise. Closes #734. ``` r library(adbcdrivermanager) # pkgload::load_all() # Use the driver manager to connect to a database db <- adbc_database_init(adbcsqlite::adbcsqlite(), uri = ":memory:") con <- adbc_connection_init(db) # Write a table flights <- nycflights13::flights # (timestamp not supported yet) flights$time_hour <- NULL stmt <- adbc_statement_init(con, adbc.ingest.target_table = "flights") adbc_statement_bind(stmt, flights) adbc_statement_execute_query(stmt) #> [1] 336776 adbc_statement_release(stmt) # March flights stmt <- adbc_statement_init(con) adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = 3 LIMIT 2") stream <- nanoarrow::nanoarrow_allocate_array_stream() adbc_statement_execute_query(stmt, stream) #> [1] -1 result <- tibble::as_tibble(stream) adbc_statement_release(stmt) result #> # A tibble: 2 × 18 #> year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time #> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> #> 1 2013 3 1 4 2159 125 318 56 #> 2 2013 3 1 50 2358 52 526 438 #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>, #> # tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>, #> # hour <dbl>, minute <dbl> # March flights with a parameter, not passing parameter stmt <- adbc_statement_init(con) adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ? LIMIT 2") stream <- nanoarrow::nanoarrow_allocate_array_stream() adbc_statement_execute_query(stmt, stream) #> [1] -1 result <- tibble::as_tibble(stream) adbc_statement_release(stmt) result #> # A tibble: 0 × 18 #> # ℹ 18 variables: year <dbl>, month <dbl>, day <dbl>, dep_time <dbl>, #> # sched_dep_time <dbl>, dep_delay <dbl>, arr_time <dbl>, #> # sched_arr_time <dbl>, arr_delay <dbl>, carrier <dbl>, flight <dbl>, #> # tailnum <dbl>, origin <dbl>, dest <dbl>, air_time <dbl>, distance <dbl>, #> # hour <dbl>, minute <dbl> # March flights with a parameter stmt <- adbc_statement_init(con) adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ? LIMIT 2") adbc_statement_bind_stream(stmt, data.frame(a = 3)) stream <- nanoarrow::nanoarrow_allocate_array_stream() adbc_statement_execute_query(stmt, stream) #> [1] -1 result <- tibble::as_tibble(stream) adbc_statement_release(stmt) result #> # A tibble: 2 × 18 #> year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time #> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> #> 1 2013 3 1 4 2159 125 318 56 #> 2 2013 3 1 50 2358 52 526 438 #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>, #> # tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>, #> # hour <dbl>, minute <dbl> # Many March flights with multiple parameters stmt <- adbc_statement_init(con) adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ?") adbc_statement_bind_stream(stmt, data.frame(a = 2:4)) stream <- nanoarrow::nanoarrow_allocate_array_stream() adbc_statement_execute_query(stmt, stream) #> [1] -1 result <- tibble::as_tibble(stream) adbc_statement_release(stmt) result #> # A tibble: 24,951 × 18 #> year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time #> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> #> 1 2013 2 1 456 500 -4 652 648 #> 2 2013 2 1 520 525 -5 816 820 #> 3 2013 2 1 527 530 -3 837 829 #> 4 2013 2 1 532 540 -8 1007 1017 #> 5 2013 2 1 540 540 0 859 850 #> 6 2013 2 1 552 600 -8 714 715 #> 7 2013 2 1 552 600 -8 919 910 #> 8 2013 2 1 552 600 -8 655 709 #> 9 2013 2 1 553 600 -7 833 815 #> 10 2013 2 1 553 600 -7 821 825 #> # ℹ 24,941 more rows #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>, #> # tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>, #> # hour <dbl>, minute <dbl> # Clean up adbc_connection_release(con) adbc_database_release(db) ``` <sup>Created on 2023-06-08 with [reprex v2.0.2](https://reprex.tidyverse.org)</sup> --------- Co-authored-by: David Li <li.david...@gmail.com> --- c/driver/sqlite/sqlite_test.cc | 41 ++++++++++++++++++ c/driver/sqlite/statement_reader.c | 85 +++++++++++++++++++++++--------------- c/validation/adbc_validation.cc | 3 ++ 3 files changed, 96 insertions(+), 33 deletions(-) diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index a7088850..5bcca0f7 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -536,6 +536,47 @@ TEST_F(SqliteReaderTest, InferTypedParams) { "[SQLite] Type mismatch in column 0: expected INT64 but got DOUBLE")); } +TEST_F(SqliteReaderTest, MultiValueParams) { + // Regression test for apache/arrow-adbc#734 + adbc_validation::StreamReader reader; + Handle<struct ArrowSchema> schema; + Handle<struct ArrowArray> batch; + + ASSERT_NO_FATAL_FAILURE(Exec("CREATE TABLE foo (col)")); + ASSERT_NO_FATAL_FAILURE( + Exec("INSERT INTO foo VALUES (1), (2), (2), (3), (3), (3), (4), (4), (4), (4)")); + + ASSERT_THAT(adbc_validation::MakeSchema(&schema.value, {{"", NANOARROW_TYPE_INT64}}), + IsOkErrno()); + ASSERT_THAT(adbc_validation::MakeBatch<int64_t>(&schema.value, &batch.value, + /*error=*/nullptr, {4, 1, 3, 2}), + IsOkErrno()); + + ASSERT_NO_FATAL_FAILURE(Bind(&batch.value, &schema.value)); + ASSERT_NO_FATAL_FAILURE( + Exec("SELECT col FROM foo WHERE col = ?", /*infer_rows=*/3, &reader)); + ASSERT_EQ(1, reader.schema->n_children); + ASSERT_EQ(NANOARROW_TYPE_INT64, reader.fields[0].type); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NO_FATAL_FAILURE( + CompareArray<int64_t>(reader.array_view->children[0], {4, 4, 4})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NO_FATAL_FAILURE( + CompareArray<int64_t>(reader.array_view->children[0], {4, 1, 3})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NO_FATAL_FAILURE( + CompareArray<int64_t>(reader.array_view->children[0], {3, 3, 2})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[0], {2})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); +} + template <typename CType> class SqliteNumericParamTest : public SqliteReaderTest, public ::testing::WithParamInterface<ArrowType> { diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c index abde44a2..2b17364a 100644 --- a/c/driver/sqlite/statement_reader.c +++ b/c/driver/sqlite/statement_reader.c @@ -382,35 +382,41 @@ int StatementReaderGetNext(struct ArrowArrayStream* self, struct ArrowArray* out sqlite3_mutex_enter(sqlite3_db_mutex(reader->db)); while (batch_size < reader->batch_size) { - if (reader->binder) { - char finished = 0; - struct AdbcError error = {0}; - AdbcStatusCode status = AdbcSqliteBinderBindNext(reader->binder, reader->db, - reader->stmt, &finished, &error); - if (status != ADBC_STATUS_OK) { - reader->done = 1; - status = EIO; - if (error.release) { - strncpy(reader->error.message, error.message, sizeof(reader->error.message)); - reader->error.message[sizeof(reader->error.message) - 1] = '\0'; - error.release(&error); - } - break; - } else if (finished) { + int rc = sqlite3_step(reader->stmt); + if (rc == SQLITE_DONE) { + if (!reader->binder) { reader->done = 1; break; + } else { + char finished = 0; + struct AdbcError error = {0}; + status = AdbcSqliteBinderBindNext(reader->binder, reader->db, reader->stmt, + &finished, &error); + if (status != ADBC_STATUS_OK) { + reader->done = 1; + status = EIO; + if (error.release) { + strncpy(reader->error.message, error.message, sizeof(reader->error.message)); + reader->error.message[sizeof(reader->error.message) - 1] = '\0'; + error.release(&error); + } + break; + } else if (finished) { + reader->done = 1; + break; + } + continue; } - } - - int rc = sqlite3_step(reader->stmt); - if (rc == SQLITE_DONE) { - reader->done = 1; - break; } else if (rc == SQLITE_ERROR) { reader->done = 1; status = EIO; StatementReaderSetError(reader); break; + } else if (rc != SQLITE_ROW) { + reader->done = 1; + status = ADBC_STATUS_INTERNAL; + StatementReaderSetError(reader); + break; } for (int col = 0; col < reader->schema.n_children; col++) { @@ -836,26 +842,39 @@ AdbcStatusCode AdbcSqliteExportReader(sqlite3* db, sqlite3_stmt* stmt, AdbcStatusCode status = StatementReaderInitializeInfer( num_columns, batch_size, validity, data, binary, current_type, error); - if (status == ADBC_STATUS_OK) { + + if (binder) { + char finished = 0; + status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error); + if (finished) { + reader->done = 1; + } + } + + if (status == ADBC_STATUS_OK && !reader->done) { int64_t num_rows = 0; while (num_rows < batch_size) { - if (binder) { - char finished = 0; - status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error); - if (status != ADBC_STATUS_OK) break; - if (finished) { + int rc = sqlite3_step(stmt); + if (rc == SQLITE_DONE) { + if (!binder) { reader->done = 1; break; + } else { + char finished = 0; + status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error); + if (status != ADBC_STATUS_OK) break; + if (finished) { + reader->done = 1; + break; + } } - } - - int rc = sqlite3_step(stmt); - if (rc == SQLITE_DONE) { - reader->done = 1; - break; + continue; } else if (rc == SQLITE_ERROR) { status = ADBC_STATUS_IO; break; + } else if (rc != SQLITE_ROW) { + status = ADBC_STATUS_INTERNAL; + break; } for (int col = 0; col < num_columns; col++) { diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc index b99f469d..8c25f11f 100644 --- a/c/validation/adbc_validation.cc +++ b/c/validation/adbc_validation.cc @@ -1463,6 +1463,9 @@ void StatementTest::TestSqlPrepareSelectParams() { auto start = nrows; auto end = nrows + reader.array->length; + ASSERT_LT(start, expected_int32.size()); + ASSERT_LE(end, expected_int32.size()); + switch (reader.fields[0].type) { case NANOARROW_TYPE_INT32: ASSERT_NO_FATAL_FAILURE(CompareArray<int32_t>(