This is an automated email from the ASF dual-hosted git repository. lidavidm pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push: new d44e3616d feat(c/driver/sqlite,python/adbc_driver_manager): bind params by name (#3362) d44e3616d is described below commit d44e3616dd5182f4278801998e3838134c98994b Author: David Li <li.david...@gmail.com> AuthorDate: Sat Sep 6 15:55:03 2025 +0900 feat(c/driver/sqlite,python/adbc_driver_manager): bind params by name (#3362) Closes #3262. --- c/driver/sqlite/sqlite.cc | 17 +++-- c/driver/sqlite/sqlite_test.cc | 38 +++++++++-- c/driver/sqlite/statement_reader.c | 56 +++++++++++---- c/driver/sqlite/statement_reader.h | 4 ++ .../adbc_driver_manager/__init__.py | 2 + .../adbc_driver_manager/_dbapi_backend.py | 69 ++++++++++++++----- .../adbc_driver_manager/dbapi.py | 42 ++++++++++-- python/adbc_driver_manager/tests/test_dbapi.py | 36 ++++++++++ .../tests/test_dbapi_polars_nopyarrow.py | 79 ++++++++++++++++++++++ 9 files changed, 295 insertions(+), 48 deletions(-) diff --git a/c/driver/sqlite/sqlite.cc b/c/driver/sqlite/sqlite.cc index 1a47a4f13..dc3a7de43 100644 --- a/c/driver/sqlite/sqlite.cc +++ b/c/driver/sqlite/sqlite.cc @@ -51,6 +51,7 @@ constexpr std::string_view kConnectionOptionLoadExtensionEntrypoint = "adbc.sqlite.load_extension.entrypoint"; /// The batch size for query results (and for initial type inference) constexpr std::string_view kStatementOptionBatchRows = "adbc.sqlite.query.batch_rows"; +constexpr std::string_view kStatementOptionBindByName = "adbc.statement.bind_by_name"; std::string_view GetColumnText(sqlite3_stmt* stmt, int index) { return { @@ -763,11 +764,11 @@ class SqliteStatement : public driver::Statement<SqliteStatement> { public: [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[SQLite]"; - Status BindImpl() { + Status BindImpl(bool ingest) { if (bind_parameters_.release) { struct AdbcError error = ADBC_ERROR_INIT; - if (AdbcStatusCode code = - InternalAdbcSqliteBinderSetArrayStream(&binder_, &bind_parameters_, &error); + if (AdbcStatusCode code = InternalAdbcSqliteBinderSetArrayStream( + &binder_, &bind_parameters_, !ingest && bind_by_name_, &error); code != ADBC_STATUS_OK) { return Status::FromAdbc(code, error); } @@ -776,7 +777,7 @@ class SqliteStatement : public driver::Statement<SqliteStatement> { } Result<int64_t> ExecuteIngestImpl(IngestState& state) { - UNWRAP_STATUS(BindImpl()); + UNWRAP_STATUS(BindImpl(true)); if (!binder_.schema.release) { return status::InvalidState("must Bind() before bulk ingestion"); } @@ -975,7 +976,7 @@ class SqliteStatement : public driver::Statement<SqliteStatement> { Result<int64_t> ExecuteQueryImpl(ArrowArrayStream* stream) { struct AdbcError error = ADBC_ERROR_INIT; - UNWRAP_STATUS(BindImpl()); + UNWRAP_STATUS(BindImpl(false)); const int64_t expected = sqlite3_bind_parameter_count(stmt_); const int64_t actual = binder_.schema.n_children; @@ -1003,7 +1004,7 @@ class SqliteStatement : public driver::Statement<SqliteStatement> { } Result<int64_t> ExecuteUpdateImpl() { - UNWRAP_STATUS(BindImpl()); + UNWRAP_STATUS(BindImpl(false)); const int64_t expected = sqlite3_bind_parameter_count(stmt_); const int64_t actual = binder_.schema.n_children; @@ -1143,11 +1144,15 @@ class SqliteStatement : public driver::Statement<SqliteStatement> { } batch_size_ = static_cast<int>(batch_size); return status::Ok(); + } else if (key == kStatementOptionBindByName) { + UNWRAP_RESULT(bind_by_name_, value.AsBool()); + return status::Ok(); } return Base::SetOptionImpl(key, std::move(value)); } int batch_size_ = 1024; + bool bind_by_name_ = false; AdbcSqliteBinder binder_; sqlite3* conn_ = nullptr; sqlite3_stmt* stmt_ = nullptr; diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index 62f15c690..f270f5059 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -454,17 +454,19 @@ class SqliteReaderTest : public ::testing::Test { stmt = nullptr; } - void Bind(struct ArrowArray* batch, struct ArrowSchema* schema) { + void Bind(struct ArrowArray* batch, struct ArrowSchema* schema, + bool bind_by_name = false) { Handle<struct ArrowArrayStream> stream; struct ArrowArray batch_internal = *batch; batch->release = nullptr; adbc_validation::MakeStream(&stream.value, schema, {batch_internal}); - ASSERT_NO_FATAL_FAILURE(Bind(&stream.value)); + ASSERT_NO_FATAL_FAILURE(Bind(&stream.value, bind_by_name)); } - void Bind(struct ArrowArrayStream* stream) { - ASSERT_THAT(InternalAdbcSqliteBinderSetArrayStream(&binder, stream, &error), - IsOkStatus(&error)); + void Bind(struct ArrowArrayStream* stream, bool bind_by_name = false) { + ASSERT_THAT( + InternalAdbcSqliteBinderSetArrayStream(&binder, stream, bind_by_name, &error), + IsOkStatus(&error)); } void ExecSelect(const std::string& values, size_t infer_rows, @@ -826,6 +828,32 @@ TEST_F(SqliteReaderTest, InferTypedParams) { "[SQLite] Type mismatch in column 0: expected INT64 but got DOUBLE")); } +TEST_F(SqliteReaderTest, BindByName) { + adbc_validation::StreamReader reader; + Handle<struct ArrowSchema> schema; + Handle<struct ArrowArray> batch; + + ASSERT_THAT(adbc_validation::MakeSchema(&schema.value, + { + {"@b", NANOARROW_TYPE_INT64}, + {"@a", NANOARROW_TYPE_INT64}, + }), + IsOkErrno()); + ASSERT_THAT((adbc_validation::MakeBatch<int64_t, int64_t>(&schema.value, &batch.value, + /*error=*/nullptr, {1}, {2})), + IsOkErrno()); + + ASSERT_NO_FATAL_FAILURE(Bind(&batch.value, &schema.value, true)); + ASSERT_NO_FATAL_FAILURE(Exec("SELECT @a, @b", /*infer_rows=*/2, &reader)); + ASSERT_EQ(2, reader.schema->n_children); + ASSERT_EQ(NANOARROW_TYPE_INT64, reader.fields[0].type); + ASSERT_EQ(NANOARROW_TYPE_INT64, reader.fields[1].type); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[0], {2})); + ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[1], {1})); +} + TEST_F(SqliteReaderTest, MultiValueParams) { // Regression test for apache/arrow-adbc#734 adbc_validation::StreamReader reader; diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c index 9eb65d48d..554bdaf20 100644 --- a/c/driver/sqlite/statement_reader.c +++ b/c/driver/sqlite/statement_reader.c @@ -35,7 +35,7 @@ #include "driver/common/utils.h" AdbcStatusCode InternalAdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, - struct AdbcError* error) { + bool bind_by_name, struct AdbcError* error) { int status = binder->params.get_schema(&binder->params, &binder->schema); if (status != 0) { const char* message = binder->params.get_last_error(&binder->params); @@ -61,6 +61,12 @@ AdbcStatusCode InternalAdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, binder->types = (enum ArrowType*)malloc(binder->schema.n_children * sizeof(enum ArrowType)); + if (bind_by_name) { + binder->param_indices = (int*)malloc(binder->schema.n_children * sizeof(int)); + // Lazily initialized below + memset(binder->param_indices, 0, binder->schema.n_children * sizeof(int)); + } + struct ArrowSchemaView view = {0}; for (int i = 0; i < binder->schema.n_children; i++) { status = ArrowSchemaViewInit(&view, binder->schema.children[i], &arrow_error); @@ -111,11 +117,12 @@ AdbcStatusCode InternalAdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, AdbcStatusCode InternalAdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder, struct ArrowArrayStream* values, + bool bind_by_name, struct AdbcError* error) { InternalAdbcSqliteBinderRelease(binder); binder->params = *values; memset(values, 0, sizeof(*values)); - return InternalAdbcSqliteBinderSet(binder, error); + return InternalAdbcSqliteBinderSet(binder, bind_by_name, error); } #define SECONDS_PER_DAY 86400 @@ -330,9 +337,27 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, return ADBC_STATUS_INTERNAL; } + if (binder->param_indices != NULL && binder->param_indices[0] == 0) { + // Lazy initialize since we have the statement now + for (int i = 0; i < binder->schema.n_children; i++) { + binder->param_indices[i] = + sqlite3_bind_parameter_index(stmt, binder->schema.children[i]->name); + if (binder->param_indices[i] == 0) { + InternalAdbcSetError(error, "could not find parameter `%s`", + binder->schema.children[i]->name); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + } + for (int col = 0; col < binder->schema.n_children; col++) { + int bind_index = col + 1; + if (binder->param_indices != NULL) { + bind_index = binder->param_indices[col]; + } + if (ArrowArrayViewIsNull(binder->batch.children[col], binder->next_row)) { - status = sqlite3_bind_null(stmt, col + 1); + status = sqlite3_bind_null(stmt, bind_index); } else { switch (binder->types[col]) { case NANOARROW_TYPE_BINARY: @@ -341,7 +366,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, case NANOARROW_TYPE_BINARY_VIEW: { struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], binder->next_row); - status = sqlite3_bind_blob(stmt, col + 1, value.data.as_char, + status = sqlite3_bind_blob(stmt, bind_index, value.data.as_char, (int)value.size_bytes, SQLITE_STATIC); break; } @@ -359,7 +384,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, col, value); return ADBC_STATUS_INVALID_ARGUMENT; } - status = sqlite3_bind_int64(stmt, col + 1, (int64_t)value); + status = sqlite3_bind_int64(stmt, bind_index, (int64_t)value); break; } case NANOARROW_TYPE_INT8: @@ -368,7 +393,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, case NANOARROW_TYPE_INT64: { int64_t value = ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row); - status = sqlite3_bind_int64(stmt, col + 1, value); + status = sqlite3_bind_int64(stmt, bind_index, value); break; } case NANOARROW_TYPE_HALF_FLOAT: @@ -376,7 +401,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, case NANOARROW_TYPE_DOUBLE: { double value = ArrowArrayViewGetDoubleUnsafe(binder->batch.children[col], binder->next_row); - status = sqlite3_bind_double(stmt, col + 1, value); + status = sqlite3_bind_double(stmt, bind_index, value); break; } case NANOARROW_TYPE_STRING: @@ -384,7 +409,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, case NANOARROW_TYPE_STRING_VIEW: { struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], binder->next_row); - status = sqlite3_bind_text(stmt, col + 1, value.data.as_char, + status = sqlite3_bind_text(stmt, bind_index, value.data.as_char, (int)value.size_bytes, SQLITE_STATIC); break; } @@ -393,11 +418,11 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row); if (ArrowArrayViewIsNull(binder->batch.children[col]->dictionary, value_index)) { - status = sqlite3_bind_null(stmt, col + 1); + status = sqlite3_bind_null(stmt, bind_index); } else { struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe( binder->batch.children[col]->dictionary, value_index); - status = sqlite3_bind_text(stmt, col + 1, value.data.as_char, + status = sqlite3_bind_text(stmt, bind_index, value.data.as_char, (int)value.size_bytes, SQLITE_STATIC); } break; @@ -418,7 +443,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, RAISE_ADBC(ArrowDate32ToIsoString((int32_t)value, &tsstr, error)); // SQLITE_TRANSIENT ensures the value is copied during bind - status = sqlite3_bind_text(stmt, col + 1, tsstr, (int)strlen(tsstr), + status = sqlite3_bind_text(stmt, bind_index, tsstr, (int)strlen(tsstr), SQLITE_TRANSIENT); free(tsstr); @@ -436,7 +461,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, RAISE_ADBC(ArrowTimestampToIsoString(value, unit, &tsstr, error)); // SQLITE_TRANSIENT ensures the value is copied during bind - status = sqlite3_bind_text(stmt, col + 1, tsstr, (int)strlen(tsstr), + status = sqlite3_bind_text(stmt, bind_index, tsstr, (int)strlen(tsstr), SQLITE_TRANSIENT); free((char*)tsstr); break; @@ -449,8 +474,8 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, } if (status != SQLITE_OK) { - InternalAdbcSetError(error, "Failed to clear statement bindings: %s", - sqlite3_errmsg(conn)); + InternalAdbcSetError(error, "Failed to bind col %d to param %d: %s", col, + bind_index, sqlite3_errmsg(conn)); return ADBC_STATUS_INTERNAL; } } @@ -470,6 +495,9 @@ void InternalAdbcSqliteBinderRelease(struct AdbcSqliteBinder* binder) { if (binder->types) { free(binder->types); } + if (binder->param_indices) { + free(binder->param_indices); + } if (binder->array.release) { binder->array.release(&binder->array); } diff --git a/c/driver/sqlite/statement_reader.h b/c/driver/sqlite/statement_reader.h index a2851f9cb..99be50133 100644 --- a/c/driver/sqlite/statement_reader.h +++ b/c/driver/sqlite/statement_reader.h @@ -19,6 +19,8 @@ #pragma once +#include <stdbool.h> + #include <arrow-adbc/adbc.h> #include <nanoarrow/nanoarrow.h> #include <sqlite3.h> @@ -33,6 +35,7 @@ struct ADBC_EXPORT AdbcSqliteBinder { struct ArrowSchema schema; struct ArrowArrayStream params; enum ArrowType* types; + int* param_indices; // Scratch space struct ArrowArray array; @@ -43,6 +46,7 @@ struct ADBC_EXPORT AdbcSqliteBinder { ADBC_EXPORT AdbcStatusCode InternalAdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder, struct ArrowArrayStream* values, + bool bind_by_name, struct AdbcError* error); ADBC_EXPORT AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, diff --git a/python/adbc_driver_manager/adbc_driver_manager/__init__.py b/python/adbc_driver_manager/adbc_driver_manager/__init__.py index 61cd8bb1e..4ff1fee9c 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/__init__.py +++ b/python/adbc_driver_manager/adbc_driver_manager/__init__.py @@ -116,6 +116,8 @@ class StatementOptions(enum.Enum): Not all drivers support all options. """ + #: Bind parameters by name instead of by position. + BIND_BY_NAME = "adbc.statement.bind_by_name" #: Enable incremental execution on ExecutePartitions. INCREMENTAL = "adbc.statement.exec.incremental" #: For bulk ingestion, whether to create or append to the table. diff --git a/python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py b/python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py index 545d76318..9457a7476 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py +++ b/python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py @@ -60,7 +60,9 @@ class DbapiBackend(abc.ABC): ... @abc.abstractmethod - def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any: + def convert_executemany_parameters( + self, parameters: typing.Any + ) -> typing.Tuple[typing.Any, bool]: """Convert an arbitrary Python sequence into bind parameters. Parameters @@ -74,6 +76,9 @@ class DbapiBackend(abc.ABC): parameters : CapsuleType This should be an Arrow stream capsule or an object implementing the Arrow PyCapsule interface. + bind_by_name : bool + Whether the parameters should be bound by name (e.g. because they + contain a dictionary). See Also -------- @@ -107,7 +112,9 @@ class _NoOpBackend(DbapiBackend): status_code=_lib.AdbcStatusCode.INVALID_STATE, ) - def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any: + def convert_executemany_parameters( + self, parameters: typing.Any + ) -> typing.Tuple[typing.Any, bool]: raise _lib.ProgrammingError( "This API requires PyArrow or another suitable backend to be installed", status_code=_lib.AdbcStatusCode.INVALID_STATE, @@ -122,6 +129,29 @@ class _NoOpBackend(DbapiBackend): return handle +def param_iterable_to_dict(parameters: typing.Any) -> typing.Tuple[dict, bool]: + bind_by_name = False + cols = {} + for param in parameters: + if not cols: + # First iteration + if isinstance(param, dict): + bind_by_name = True + for k, v in param.items(): + cols[str(k)] = [] + else: + for col_idx, v in enumerate(param): + cols[str(col_idx)] = [] + + if isinstance(param, dict): + for k, v in param.items(): + cols[str(k)].append(v) + else: + for col_idx, v in enumerate(param): + cols[str(col_idx)].append(v) + return cols, bind_by_name + + _ALL_BACKENDS.append(_NoOpBackend()) try: @@ -129,18 +159,21 @@ try: class _PolarsBackend(DbapiBackend): def convert_bind_parameters(self, parameters: typing.Any) -> polars.DataFrame: - return polars.DataFrame( - {str(col_idx): x for col_idx, x in enumerate(parameters)}, - ) + if isinstance(parameters, dict): + return polars.DataFrame( + {str(k): v for k, v in parameters.items()}, + ) - def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any: return polars.DataFrame( - { - str(col_idx): x - for col_idx, x in enumerate(map(list, zip(*parameters))) - }, + {str(col_idx): v for col_idx, v in enumerate(parameters)}, ) + def convert_executemany_parameters( + self, parameters: typing.Any + ) -> typing.Tuple[typing.Any, bool]: + cols, bind_by_name = param_iterable_to_dict(parameters) + return polars.DataFrame(cols), bind_by_name + def import_array_stream( self, handle: _lib.ArrowArrayStreamHandle ) -> typing.Any: @@ -159,18 +192,20 @@ try: class _PyArrowBackend(DbapiBackend): def convert_bind_parameters(self, parameters: typing.Any) -> typing.Any: + if isinstance(parameters, dict): + return pyarrow.record_batch( + {str(k): [v] for k, v in parameters.items()}, + ) return pyarrow.record_batch( [[param_value] for param_value in parameters], names=[str(i) for i in range(len(parameters))], ) - def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any: - return pyarrow.RecordBatch.from_pydict( - { - str(col_idx): pyarrow.array(x) - for col_idx, x in enumerate(map(list, zip(*parameters))) - }, - ) + def convert_executemany_parameters( + self, parameters: typing.Any + ) -> typing.Tuple[typing.Any, bool]: + cols, bind_by_name = param_iterable_to_dict(parameters) + return pyarrow.RecordBatch.from_pydict(cols), bind_by_name def import_array_stream( self, handle: _lib.ArrowArrayStreamHandle diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index c4b35458b..4c3641dca 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -619,6 +619,7 @@ class Cursor(_Closeable): self._results: Optional["_RowIterator"] = None self._arraysize = 1 self._rowcount = -1 + self._bind_by_name = False if adbc_stmt_kwargs: self._stmt.set_options(**adbc_stmt_kwargs) @@ -711,6 +712,17 @@ class Cursor(_Closeable): rb = self._conn._backend.convert_bind_parameters(parameters) self._bind(rb) + if isinstance(parameters, dict) and not self._bind_by_name: + self._stmt.set_options( + **{adbc_driver_manager.StatementOptions.BIND_BY_NAME.value: "true"} + ) + self._bind_by_name = True + elif not isinstance(parameters, dict) and self._bind_by_name: + self._stmt.set_options( + **{adbc_driver_manager.StatementOptions.BIND_BY_NAME.value: "false"} + ) + self._bind_by_name = False + def execute(self, operation: Union[bytes, str], parameters=None) -> None: """ Execute a query. @@ -721,10 +733,17 @@ class Cursor(_Closeable): The query to execute. Pass SQL queries as strings, (serialized) Substrait plans as bytes. parameters - Parameters to bind. Can be a Python sequence (to provide - a single set of parameters), or an Arrow record batch, - table, or record batch reader (to provide multiple - parameters, which will each be bound in turn). + Parameters to bind. Can be a Python sequence (to bind a single + set of parameters), a Python dictionary (to bind a single set of + parameters by name instead of position), or an Arrow record batch, + table, or record batch reader (to provide multiple parameters, + which will each be bound in turn). + + To bind by name when providing Arrow data, explicitly toggle the + statement option "adbc.statement.bind_by_name". + + Note that providing a list of tuples is not supported (this mode + of usage is deprecated in DBAPI-2.0; use executemany() instead). """ self._clear() self._prepare_execute(operation, parameters) @@ -763,15 +782,26 @@ class Cursor(_Closeable): self._stmt.set_sql_query(operation) self._stmt.prepare() + bind_by_name = None if _is_arrow_data(seq_of_parameters): arrow_parameters = seq_of_parameters elif seq_of_parameters: - arrow_parameters = self._conn._backend.convert_executemany_parameters( - seq_of_parameters + arrow_parameters, bind_by_name = ( + self._conn._backend.convert_executemany_parameters(seq_of_parameters) ) else: arrow_parameters = None + if bind_by_name is not None and bind_by_name != self._bind_by_name: + self._stmt.set_options( + **{ + adbc_driver_manager.StatementOptions.BIND_BY_NAME.value: ( + "true" if bind_by_name else "false" + ), + } + ) + self._bind_by_name = bind_by_name + if arrow_parameters is not None: self._bind(arrow_parameters) elif seq_of_parameters is not None: diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py index 527a263a6..72699b13e 100644 --- a/python/adbc_driver_manager/tests/test_dbapi.py +++ b/python/adbc_driver_manager/tests/test_dbapi.py @@ -314,6 +314,42 @@ def test_execute_parameters(sqlite, parameters): assert cur.fetchall() == [(2.0, 2)] +@pytest.mark.sqlite +def test_execute_parameters_name(sqlite): + with sqlite.cursor() as cur: + cur.execute("SELECT @a + 1, @b", {"@b": 2, "@a": 1}) + assert cur.fetchall() == [(2, 2)] + + # Ensure the state of the cursor isn't affected + cur.execute("SELECT ?2 + 1, ?1", [2, 1]) + assert cur.fetchall() == [(2, 2)] + + cur.execute("SELECT @a + 1, @b + @b", {"@b": 2, "@a": 1}) + assert cur.fetchall() == [(2, 4)] + + data = pyarrow.record_batch([[1.0], [2]], names=["float", "int"]) + cur.adbc_ingest("ingest_tester", data) + cur.execute("SELECT * FROM ingest_tester") + assert cur.fetchall() == [(1.0, 2)] + + +@pytest.mark.sqlite +def test_executemany_parameters_name(sqlite): + with sqlite.cursor() as cur: + cur.execute("CREATE TABLE executemany_params (a, b)") + + cur.executemany( + "INSERT INTO executemany_params VALUES (@a, @b)", + [{"@b": 2, "@a": 1}, {"@b": 3, "@a": 2}], + ) + cur.executemany( + "INSERT INTO executemany_params VALUES (?, ?)", [(3, 4), (4, 5)] + ) + + cur.execute("SELECT * FROM executemany_params ORDER BY a ASC") + assert cur.fetchall() == [(1, 2), (2, 3), (3, 4), (4, 5)] + + @pytest.mark.sqlite @pytest.mark.parametrize( "parameters", diff --git a/python/adbc_driver_manager/tests/test_dbapi_polars_nopyarrow.py b/python/adbc_driver_manager/tests/test_dbapi_polars_nopyarrow.py index 6b36a2189..da86fd3a1 100644 --- a/python/adbc_driver_manager/tests/test_dbapi_polars_nopyarrow.py +++ b/python/adbc_driver_manager/tests/test_dbapi_polars_nopyarrow.py @@ -152,6 +152,85 @@ def test_query_executemany_parameters(sqlite: dbapi.Connection, parameters) -> N ) +def test_execute_parameters_name(sqlite): + with sqlite.cursor() as cursor: + cursor.execute("SELECT @a + 1, @b", {"@b": 2, "@a": 1}) + df = cursor.fetch_polars() + polars.testing.assert_frame_equal( + df, + polars.DataFrame( + { + "@a + 1": [2], + "@b": [2], + } + ), + ) + + # Ensure the state of the cursor isn't affected + cursor.execute("SELECT ?2 + 1, ?1", [2, 1]) + df = cursor.fetch_polars() + polars.testing.assert_frame_equal( + df, + polars.DataFrame( + { + "?2 + 1": [2], + "?1": [2], + } + ), + ) + + cursor.execute("SELECT @a + 1, @b + @b", {"@b": 2, "@a": 1}) + df = cursor.fetch_polars() + polars.testing.assert_frame_equal( + df, + polars.DataFrame( + { + "@a + 1": [2], + "@b + @b": [4], + } + ), + ) + + data = polars.DataFrame({"float": [1.0], "int": [2]}) + cursor.adbc_ingest("ingest_tester", data) + cursor.execute("SELECT * FROM ingest_tester") + df = cursor.fetch_polars() + polars.testing.assert_frame_equal( + df, + polars.DataFrame( + { + "float": [1.0], + "int": [2], + } + ), + ) + + +def test_executemany_parameters_name(sqlite): + with sqlite.cursor() as cursor: + cursor.execute("CREATE TABLE executemany_params (a, b)") + + cursor.executemany( + "INSERT INTO executemany_params VALUES (@a, @b)", + [{"@b": 2, "@a": 1}, {"@b": 3, "@a": 2}], + ) + cursor.executemany( + "INSERT INTO executemany_params VALUES (?, ?)", [(3, 4), (4, 5)] + ) + + cursor.execute("SELECT * FROM executemany_params ORDER BY a ASC") + df = cursor.fetch_polars() + polars.testing.assert_frame_equal( + df, + polars.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [2, 3, 4, 5], + } + ), + ) + + @pytest.mark.parametrize( "parameters", [