This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new d44e3616d feat(c/driver/sqlite,python/adbc_driver_manager): bind params by name (#3362)
d44e3616d is described below

commit d44e3616dd5182f4278801998e3838134c98994b
Author: David Li <li.david...@gmail.com>
AuthorDate: Sat Sep 6 15:55:03 2025 +0900

    feat(c/driver/sqlite,python/adbc_driver_manager): bind params by name (#3362)
    
    Closes #3262.
---
 c/driver/sqlite/sqlite.cc                          | 17 +++--
 c/driver/sqlite/sqlite_test.cc                     | 38 +++++++++--
 c/driver/sqlite/statement_reader.c                 | 56 +++++++++++----
 c/driver/sqlite/statement_reader.h                 |  4 ++
 .../adbc_driver_manager/__init__.py                |  2 +
 .../adbc_driver_manager/_dbapi_backend.py          | 69 ++++++++++++++-----
 .../adbc_driver_manager/dbapi.py                   | 42 ++++++++++--
 python/adbc_driver_manager/tests/test_dbapi.py     | 36 ++++++++++
 .../tests/test_dbapi_polars_nopyarrow.py           | 79 ++++++++++++++++++++++
 9 files changed, 295 insertions(+), 48 deletions(-)

diff --git a/c/driver/sqlite/sqlite.cc b/c/driver/sqlite/sqlite.cc
index 1a47a4f13..dc3a7de43 100644
--- a/c/driver/sqlite/sqlite.cc
+++ b/c/driver/sqlite/sqlite.cc
@@ -51,6 +51,7 @@ constexpr std::string_view 
kConnectionOptionLoadExtensionEntrypoint =
     "adbc.sqlite.load_extension.entrypoint";
 /// The batch size for query results (and for initial type inference)
 constexpr std::string_view kStatementOptionBatchRows = 
"adbc.sqlite.query.batch_rows";
+constexpr std::string_view kStatementOptionBindByName = 
"adbc.statement.bind_by_name";
 
 std::string_view GetColumnText(sqlite3_stmt* stmt, int index) {
   return {
@@ -763,11 +764,11 @@ class SqliteStatement : public 
driver::Statement<SqliteStatement> {
  public:
   [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[SQLite]";
 
-  Status BindImpl() {
+  Status BindImpl(bool ingest) {
     if (bind_parameters_.release) {
       struct AdbcError error = ADBC_ERROR_INIT;
-      if (AdbcStatusCode code =
-              InternalAdbcSqliteBinderSetArrayStream(&binder_, 
&bind_parameters_, &error);
+      if (AdbcStatusCode code = InternalAdbcSqliteBinderSetArrayStream(
+              &binder_, &bind_parameters_, !ingest && bind_by_name_, &error);
           code != ADBC_STATUS_OK) {
         return Status::FromAdbc(code, error);
       }
@@ -776,7 +777,7 @@ class SqliteStatement : public 
driver::Statement<SqliteStatement> {
   }
 
   Result<int64_t> ExecuteIngestImpl(IngestState& state) {
-    UNWRAP_STATUS(BindImpl());
+    UNWRAP_STATUS(BindImpl(true));
     if (!binder_.schema.release) {
       return status::InvalidState("must Bind() before bulk ingestion");
     }
@@ -975,7 +976,7 @@ class SqliteStatement : public 
driver::Statement<SqliteStatement> {
 
   Result<int64_t> ExecuteQueryImpl(ArrowArrayStream* stream) {
     struct AdbcError error = ADBC_ERROR_INIT;
-    UNWRAP_STATUS(BindImpl());
+    UNWRAP_STATUS(BindImpl(false));
 
     const int64_t expected = sqlite3_bind_parameter_count(stmt_);
     const int64_t actual = binder_.schema.n_children;
@@ -1003,7 +1004,7 @@ class SqliteStatement : public 
driver::Statement<SqliteStatement> {
   }
 
   Result<int64_t> ExecuteUpdateImpl() {
-    UNWRAP_STATUS(BindImpl());
+    UNWRAP_STATUS(BindImpl(false));
 
     const int64_t expected = sqlite3_bind_parameter_count(stmt_);
     const int64_t actual = binder_.schema.n_children;
@@ -1143,11 +1144,15 @@ class SqliteStatement : public 
driver::Statement<SqliteStatement> {
       }
       batch_size_ = static_cast<int>(batch_size);
       return status::Ok();
+    } else if (key == kStatementOptionBindByName) {
+      UNWRAP_RESULT(bind_by_name_, value.AsBool());
+      return status::Ok();
     }
     return Base::SetOptionImpl(key, std::move(value));
   }
 
   int batch_size_ = 1024;
+  bool bind_by_name_ = false;
   AdbcSqliteBinder binder_;
   sqlite3* conn_ = nullptr;
   sqlite3_stmt* stmt_ = nullptr;
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index 62f15c690..f270f5059 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -454,17 +454,19 @@ class SqliteReaderTest : public ::testing::Test {
     stmt = nullptr;
   }
 
-  void Bind(struct ArrowArray* batch, struct ArrowSchema* schema) {
+  void Bind(struct ArrowArray* batch, struct ArrowSchema* schema,
+            bool bind_by_name = false) {
     Handle<struct ArrowArrayStream> stream;
     struct ArrowArray batch_internal = *batch;
     batch->release = nullptr;
     adbc_validation::MakeStream(&stream.value, schema, {batch_internal});
-    ASSERT_NO_FATAL_FAILURE(Bind(&stream.value));
+    ASSERT_NO_FATAL_FAILURE(Bind(&stream.value, bind_by_name));
   }
 
-  void Bind(struct ArrowArrayStream* stream) {
-    ASSERT_THAT(InternalAdbcSqliteBinderSetArrayStream(&binder, stream, 
&error),
-                IsOkStatus(&error));
+  void Bind(struct ArrowArrayStream* stream, bool bind_by_name = false) {
+    ASSERT_THAT(
+        InternalAdbcSqliteBinderSetArrayStream(&binder, stream, bind_by_name, 
&error),
+        IsOkStatus(&error));
   }
 
   void ExecSelect(const std::string& values, size_t infer_rows,
@@ -826,6 +828,32 @@ TEST_F(SqliteReaderTest, InferTypedParams) {
                   "[SQLite] Type mismatch in column 0: expected INT64 but got 
DOUBLE"));
 }
 
+TEST_F(SqliteReaderTest, BindByName) {
+  adbc_validation::StreamReader reader;
+  Handle<struct ArrowSchema> schema;
+  Handle<struct ArrowArray> batch;
+
+  ASSERT_THAT(adbc_validation::MakeSchema(&schema.value,
+                                          {
+                                              {"@b", NANOARROW_TYPE_INT64},
+                                              {"@a", NANOARROW_TYPE_INT64},
+                                          }),
+              IsOkErrno());
+  ASSERT_THAT((adbc_validation::MakeBatch<int64_t, int64_t>(&schema.value, 
&batch.value,
+                                                            /*error=*/nullptr, 
{1}, {2})),
+              IsOkErrno());
+
+  ASSERT_NO_FATAL_FAILURE(Bind(&batch.value, &schema.value, true));
+  ASSERT_NO_FATAL_FAILURE(Exec("SELECT @a, @b", /*infer_rows=*/2, &reader));
+  ASSERT_EQ(2, reader.schema->n_children);
+  ASSERT_EQ(NANOARROW_TYPE_INT64, reader.fields[0].type);
+  ASSERT_EQ(NANOARROW_TYPE_INT64, reader.fields[1].type);
+
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  
ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[0], 
{2}));
+  
ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[1], 
{1}));
+}
+
 TEST_F(SqliteReaderTest, MultiValueParams) {
   // Regression test for apache/arrow-adbc#734
   adbc_validation::StreamReader reader;
diff --git a/c/driver/sqlite/statement_reader.c 
b/c/driver/sqlite/statement_reader.c
index 9eb65d48d..554bdaf20 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -35,7 +35,7 @@
 #include "driver/common/utils.h"
 
 AdbcStatusCode InternalAdbcSqliteBinderSet(struct AdbcSqliteBinder* binder,
-                                           struct AdbcError* error) {
+                                           bool bind_by_name, struct 
AdbcError* error) {
   int status = binder->params.get_schema(&binder->params, &binder->schema);
   if (status != 0) {
     const char* message = binder->params.get_last_error(&binder->params);
@@ -61,6 +61,12 @@ AdbcStatusCode InternalAdbcSqliteBinderSet(struct 
AdbcSqliteBinder* binder,
   binder->types =
       (enum ArrowType*)malloc(binder->schema.n_children * sizeof(enum 
ArrowType));
 
+  if (bind_by_name) {
+    binder->param_indices = (int*)malloc(binder->schema.n_children * 
sizeof(int));
+    // Lazily initialized below
+    memset(binder->param_indices, 0, binder->schema.n_children * sizeof(int));
+  }
+
   struct ArrowSchemaView view = {0};
   for (int i = 0; i < binder->schema.n_children; i++) {
     status = ArrowSchemaViewInit(&view, binder->schema.children[i], 
&arrow_error);
@@ -111,11 +117,12 @@ AdbcStatusCode InternalAdbcSqliteBinderSet(struct 
AdbcSqliteBinder* binder,
 
 AdbcStatusCode InternalAdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* 
binder,
                                                       struct ArrowArrayStream* 
values,
+                                                      bool bind_by_name,
                                                       struct AdbcError* error) 
{
   InternalAdbcSqliteBinderRelease(binder);
   binder->params = *values;
   memset(values, 0, sizeof(*values));
-  return InternalAdbcSqliteBinderSet(binder, error);
+  return InternalAdbcSqliteBinderSet(binder, bind_by_name, error);
 }
 
 #define SECONDS_PER_DAY 86400
@@ -330,9 +337,27 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
     return ADBC_STATUS_INTERNAL;
   }
 
+  if (binder->param_indices != NULL && binder->param_indices[0] == 0) {
+    // Lazy initialize since we have the statement now
+    for (int i = 0; i < binder->schema.n_children; i++) {
+      binder->param_indices[i] =
+          sqlite3_bind_parameter_index(stmt, binder->schema.children[i]->name);
+      if (binder->param_indices[i] == 0) {
+        InternalAdbcSetError(error, "could not find parameter `%s`",
+                             binder->schema.children[i]->name);
+        return ADBC_STATUS_INVALID_ARGUMENT;
+      }
+    }
+  }
+
   for (int col = 0; col < binder->schema.n_children; col++) {
+    int bind_index = col + 1;
+    if (binder->param_indices != NULL) {
+      bind_index = binder->param_indices[col];
+    }
+
     if (ArrowArrayViewIsNull(binder->batch.children[col], binder->next_row)) {
-      status = sqlite3_bind_null(stmt, col + 1);
+      status = sqlite3_bind_null(stmt, bind_index);
     } else {
       switch (binder->types[col]) {
         case NANOARROW_TYPE_BINARY:
@@ -341,7 +366,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
         case NANOARROW_TYPE_BINARY_VIEW: {
           struct ArrowBufferView value =
               ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], 
binder->next_row);
-          status = sqlite3_bind_blob(stmt, col + 1, value.data.as_char,
+          status = sqlite3_bind_blob(stmt, bind_index, value.data.as_char,
                                      (int)value.size_bytes, SQLITE_STATIC);
           break;
         }
@@ -359,7 +384,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
                                  col, value);
             return ADBC_STATUS_INVALID_ARGUMENT;
           }
-          status = sqlite3_bind_int64(stmt, col + 1, (int64_t)value);
+          status = sqlite3_bind_int64(stmt, bind_index, (int64_t)value);
           break;
         }
         case NANOARROW_TYPE_INT8:
@@ -368,7 +393,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
         case NANOARROW_TYPE_INT64: {
           int64_t value =
               ArrowArrayViewGetIntUnsafe(binder->batch.children[col], 
binder->next_row);
-          status = sqlite3_bind_int64(stmt, col + 1, value);
+          status = sqlite3_bind_int64(stmt, bind_index, value);
           break;
         }
         case NANOARROW_TYPE_HALF_FLOAT:
@@ -376,7 +401,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
         case NANOARROW_TYPE_DOUBLE: {
           double value = 
ArrowArrayViewGetDoubleUnsafe(binder->batch.children[col],
                                                        binder->next_row);
-          status = sqlite3_bind_double(stmt, col + 1, value);
+          status = sqlite3_bind_double(stmt, bind_index, value);
           break;
         }
         case NANOARROW_TYPE_STRING:
@@ -384,7 +409,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
         case NANOARROW_TYPE_STRING_VIEW: {
           struct ArrowBufferView value =
               ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], 
binder->next_row);
-          status = sqlite3_bind_text(stmt, col + 1, value.data.as_char,
+          status = sqlite3_bind_text(stmt, bind_index, value.data.as_char,
                                      (int)value.size_bytes, SQLITE_STATIC);
           break;
         }
@@ -393,11 +418,11 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
               ArrowArrayViewGetIntUnsafe(binder->batch.children[col], 
binder->next_row);
           if (ArrowArrayViewIsNull(binder->batch.children[col]->dictionary,
                                    value_index)) {
-            status = sqlite3_bind_null(stmt, col + 1);
+            status = sqlite3_bind_null(stmt, bind_index);
           } else {
             struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(
                 binder->batch.children[col]->dictionary, value_index);
-            status = sqlite3_bind_text(stmt, col + 1, value.data.as_char,
+            status = sqlite3_bind_text(stmt, bind_index, value.data.as_char,
                                        (int)value.size_bytes, SQLITE_STATIC);
           }
           break;
@@ -418,7 +443,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
 
           RAISE_ADBC(ArrowDate32ToIsoString((int32_t)value, &tsstr, error));
           // SQLITE_TRANSIENT ensures the value is copied during bind
-          status = sqlite3_bind_text(stmt, col + 1, tsstr, (int)strlen(tsstr),
+          status = sqlite3_bind_text(stmt, bind_index, tsstr, 
(int)strlen(tsstr),
                                      SQLITE_TRANSIENT);
 
           free(tsstr);
@@ -436,7 +461,7 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
           RAISE_ADBC(ArrowTimestampToIsoString(value, unit, &tsstr, error));
 
           // SQLITE_TRANSIENT ensures the value is copied during bind
-          status = sqlite3_bind_text(stmt, col + 1, tsstr, (int)strlen(tsstr),
+          status = sqlite3_bind_text(stmt, bind_index, tsstr, 
(int)strlen(tsstr),
                                      SQLITE_TRANSIENT);
           free((char*)tsstr);
           break;
@@ -449,8 +474,8 @@ AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder,
     }
 
     if (status != SQLITE_OK) {
-      InternalAdbcSetError(error, "Failed to clear statement bindings: %s",
-                           sqlite3_errmsg(conn));
+      InternalAdbcSetError(error, "Failed to bind col %d to param %d: %s", col,
+                           bind_index, sqlite3_errmsg(conn));
       return ADBC_STATUS_INTERNAL;
     }
   }
@@ -470,6 +495,9 @@ void InternalAdbcSqliteBinderRelease(struct 
AdbcSqliteBinder* binder) {
   if (binder->types) {
     free(binder->types);
   }
+  if (binder->param_indices) {
+    free(binder->param_indices);
+  }
   if (binder->array.release) {
     binder->array.release(&binder->array);
   }
diff --git a/c/driver/sqlite/statement_reader.h 
b/c/driver/sqlite/statement_reader.h
index a2851f9cb..99be50133 100644
--- a/c/driver/sqlite/statement_reader.h
+++ b/c/driver/sqlite/statement_reader.h
@@ -19,6 +19,8 @@
 
 #pragma once
 
+#include <stdbool.h>
+
 #include <arrow-adbc/adbc.h>
 #include <nanoarrow/nanoarrow.h>
 #include <sqlite3.h>
@@ -33,6 +35,7 @@ struct ADBC_EXPORT AdbcSqliteBinder {
   struct ArrowSchema schema;
   struct ArrowArrayStream params;
   enum ArrowType* types;
+  int* param_indices;
 
   // Scratch space
   struct ArrowArray array;
@@ -43,6 +46,7 @@ struct ADBC_EXPORT AdbcSqliteBinder {
 ADBC_EXPORT
 AdbcStatusCode InternalAdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* 
binder,
                                                       struct ArrowArrayStream* 
values,
+                                                      bool bind_by_name,
                                                       struct AdbcError* error);
 ADBC_EXPORT
 AdbcStatusCode InternalAdbcSqliteBinderBindNext(struct AdbcSqliteBinder* 
binder,
diff --git a/python/adbc_driver_manager/adbc_driver_manager/__init__.py 
b/python/adbc_driver_manager/adbc_driver_manager/__init__.py
index 61cd8bb1e..4ff1fee9c 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/__init__.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/__init__.py
@@ -116,6 +116,8 @@ class StatementOptions(enum.Enum):
     Not all drivers support all options.
     """
 
+    #: Bind parameters by name instead of by position.
+    BIND_BY_NAME = "adbc.statement.bind_by_name"
     #: Enable incremental execution on ExecutePartitions.
     INCREMENTAL = "adbc.statement.exec.incremental"
     #: For bulk ingestion, whether to create or append to the table.
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py 
b/python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py
index 545d76318..9457a7476 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py
@@ -60,7 +60,9 @@ class DbapiBackend(abc.ABC):
         ...
 
     @abc.abstractmethod
-    def convert_executemany_parameters(self, parameters: typing.Any) -> 
typing.Any:
+    def convert_executemany_parameters(
+        self, parameters: typing.Any
+    ) -> typing.Tuple[typing.Any, bool]:
         """Convert an arbitrary Python sequence into bind parameters.
 
         Parameters
@@ -74,6 +76,9 @@ class DbapiBackend(abc.ABC):
         parameters : CapsuleType
             This should be an Arrow stream capsule or an object implementing
             the Arrow PyCapsule interface.
+        bind_by_name : bool
+            Whether the parameters should be bound by name (e.g. because they
+            contain a dictionary).
 
         See Also
         --------
@@ -107,7 +112,9 @@ class _NoOpBackend(DbapiBackend):
             status_code=_lib.AdbcStatusCode.INVALID_STATE,
         )
 
-    def convert_executemany_parameters(self, parameters: typing.Any) -> 
typing.Any:
+    def convert_executemany_parameters(
+        self, parameters: typing.Any
+    ) -> typing.Tuple[typing.Any, bool]:
         raise _lib.ProgrammingError(
             "This API requires PyArrow or another suitable backend to be 
installed",
             status_code=_lib.AdbcStatusCode.INVALID_STATE,
@@ -122,6 +129,29 @@ class _NoOpBackend(DbapiBackend):
         return handle
 
 
+def param_iterable_to_dict(parameters: typing.Any) -> typing.Tuple[dict, bool]:
+    bind_by_name = False
+    cols = {}
+    for param in parameters:
+        if not cols:
+            # First iteration
+            if isinstance(param, dict):
+                bind_by_name = True
+                for k, v in param.items():
+                    cols[str(k)] = []
+            else:
+                for col_idx, v in enumerate(param):
+                    cols[str(col_idx)] = []
+
+        if isinstance(param, dict):
+            for k, v in param.items():
+                cols[str(k)].append(v)
+        else:
+            for col_idx, v in enumerate(param):
+                cols[str(col_idx)].append(v)
+    return cols, bind_by_name
+
+
 _ALL_BACKENDS.append(_NoOpBackend())
 
 try:
@@ -129,18 +159,21 @@ try:
 
     class _PolarsBackend(DbapiBackend):
         def convert_bind_parameters(self, parameters: typing.Any) -> 
polars.DataFrame:
-            return polars.DataFrame(
-                {str(col_idx): x for col_idx, x in enumerate(parameters)},
-            )
+            if isinstance(parameters, dict):
+                return polars.DataFrame(
+                    {str(k): v for k, v in parameters.items()},
+                )
 
-        def convert_executemany_parameters(self, parameters: typing.Any) -> 
typing.Any:
             return polars.DataFrame(
-                {
-                    str(col_idx): x
-                    for col_idx, x in enumerate(map(list, zip(*parameters)))
-                },
+                {str(col_idx): v for col_idx, v in enumerate(parameters)},
             )
 
+        def convert_executemany_parameters(
+            self, parameters: typing.Any
+        ) -> typing.Tuple[typing.Any, bool]:
+            cols, bind_by_name = param_iterable_to_dict(parameters)
+            return polars.DataFrame(cols), bind_by_name
+
         def import_array_stream(
             self, handle: _lib.ArrowArrayStreamHandle
         ) -> typing.Any:
@@ -159,18 +192,20 @@ try:
 
     class _PyArrowBackend(DbapiBackend):
         def convert_bind_parameters(self, parameters: typing.Any) -> 
typing.Any:
+            if isinstance(parameters, dict):
+                return pyarrow.record_batch(
+                    {str(k): [v] for k, v in parameters.items()},
+                )
             return pyarrow.record_batch(
                 [[param_value] for param_value in parameters],
                 names=[str(i) for i in range(len(parameters))],
             )
 
-        def convert_executemany_parameters(self, parameters: typing.Any) -> 
typing.Any:
-            return pyarrow.RecordBatch.from_pydict(
-                {
-                    str(col_idx): pyarrow.array(x)
-                    for col_idx, x in enumerate(map(list, zip(*parameters)))
-                },
-            )
+        def convert_executemany_parameters(
+            self, parameters: typing.Any
+        ) -> typing.Tuple[typing.Any, bool]:
+            cols, bind_by_name = param_iterable_to_dict(parameters)
+            return pyarrow.RecordBatch.from_pydict(cols), bind_by_name
 
         def import_array_stream(
             self, handle: _lib.ArrowArrayStreamHandle
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py 
b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index c4b35458b..4c3641dca 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -619,6 +619,7 @@ class Cursor(_Closeable):
         self._results: Optional["_RowIterator"] = None
         self._arraysize = 1
         self._rowcount = -1
+        self._bind_by_name = False
 
         if adbc_stmt_kwargs:
             self._stmt.set_options(**adbc_stmt_kwargs)
@@ -711,6 +712,17 @@ class Cursor(_Closeable):
             rb = self._conn._backend.convert_bind_parameters(parameters)
             self._bind(rb)
 
+            if isinstance(parameters, dict) and not self._bind_by_name:
+                self._stmt.set_options(
+                    
**{adbc_driver_manager.StatementOptions.BIND_BY_NAME.value: "true"}
+                )
+                self._bind_by_name = True
+            elif not isinstance(parameters, dict) and self._bind_by_name:
+                self._stmt.set_options(
+                    
**{adbc_driver_manager.StatementOptions.BIND_BY_NAME.value: "false"}
+                )
+                self._bind_by_name = False
+
     def execute(self, operation: Union[bytes, str], parameters=None) -> None:
         """
         Execute a query.
@@ -721,10 +733,17 @@ class Cursor(_Closeable):
             The query to execute.  Pass SQL queries as strings,
             (serialized) Substrait plans as bytes.
         parameters
-            Parameters to bind.  Can be a Python sequence (to provide
-            a single set of parameters), or an Arrow record batch,
-            table, or record batch reader (to provide multiple
-            parameters, which will each be bound in turn).
+            Parameters to bind.  Can be a Python sequence (to bind a single
+            set of parameters), a Python dictionary (to bind a single set of
+            parameters by name instead of position), or an Arrow record batch,
+            table, or record batch reader (to provide multiple parameters,
+            which will each be bound in turn).
+
+            To bind by name when providing Arrow data, explicitly toggle the
+            statement option "adbc.statement.bind_by_name".
+
+            Note that providing a list of tuples is not supported (this mode
+            of usage is deprecated in DBAPI-2.0; use executemany() instead).
         """
         self._clear()
         self._prepare_execute(operation, parameters)
@@ -763,15 +782,26 @@ class Cursor(_Closeable):
             self._stmt.set_sql_query(operation)
             self._stmt.prepare()
 
+        bind_by_name = None
         if _is_arrow_data(seq_of_parameters):
             arrow_parameters = seq_of_parameters
         elif seq_of_parameters:
-            arrow_parameters = 
self._conn._backend.convert_executemany_parameters(
-                seq_of_parameters
+            arrow_parameters, bind_by_name = (
+                
self._conn._backend.convert_executemany_parameters(seq_of_parameters)
             )
         else:
             arrow_parameters = None
 
+        if bind_by_name is not None and bind_by_name != self._bind_by_name:
+            self._stmt.set_options(
+                **{
+                    adbc_driver_manager.StatementOptions.BIND_BY_NAME.value: (
+                        "true" if bind_by_name else "false"
+                    ),
+                }
+            )
+            self._bind_by_name = bind_by_name
+
         if arrow_parameters is not None:
             self._bind(arrow_parameters)
         elif seq_of_parameters is not None:
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py 
b/python/adbc_driver_manager/tests/test_dbapi.py
index 527a263a6..72699b13e 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -314,6 +314,42 @@ def test_execute_parameters(sqlite, parameters):
         assert cur.fetchall() == [(2.0, 2)]
 
 
+@pytest.mark.sqlite
+def test_execute_parameters_name(sqlite):
+    with sqlite.cursor() as cur:
+        cur.execute("SELECT @a + 1, @b", {"@b": 2, "@a": 1})
+        assert cur.fetchall() == [(2, 2)]
+
+        # Ensure the state of the cursor isn't affected
+        cur.execute("SELECT ?2 + 1, ?1", [2, 1])
+        assert cur.fetchall() == [(2, 2)]
+
+        cur.execute("SELECT @a + 1, @b + @b", {"@b": 2, "@a": 1})
+        assert cur.fetchall() == [(2, 4)]
+
+        data = pyarrow.record_batch([[1.0], [2]], names=["float", "int"])
+        cur.adbc_ingest("ingest_tester", data)
+        cur.execute("SELECT * FROM ingest_tester")
+        assert cur.fetchall() == [(1.0, 2)]
+
+
+@pytest.mark.sqlite
+def test_executemany_parameters_name(sqlite):
+    with sqlite.cursor() as cur:
+        cur.execute("CREATE TABLE executemany_params (a, b)")
+
+        cur.executemany(
+            "INSERT INTO executemany_params VALUES (@a, @b)",
+            [{"@b": 2, "@a": 1}, {"@b": 3, "@a": 2}],
+        )
+        cur.executemany(
+            "INSERT INTO executemany_params VALUES (?, ?)", [(3, 4), (4, 5)]
+        )
+
+        cur.execute("SELECT * FROM executemany_params ORDER BY a ASC")
+        assert cur.fetchall() == [(1, 2), (2, 3), (3, 4), (4, 5)]
+
+
 @pytest.mark.sqlite
 @pytest.mark.parametrize(
     "parameters",
diff --git a/python/adbc_driver_manager/tests/test_dbapi_polars_nopyarrow.py 
b/python/adbc_driver_manager/tests/test_dbapi_polars_nopyarrow.py
index 6b36a2189..da86fd3a1 100644
--- a/python/adbc_driver_manager/tests/test_dbapi_polars_nopyarrow.py
+++ b/python/adbc_driver_manager/tests/test_dbapi_polars_nopyarrow.py
@@ -152,6 +152,85 @@ def test_query_executemany_parameters(sqlite: 
dbapi.Connection, parameters) -> N
         )
 
 
+def test_execute_parameters_name(sqlite):
+    with sqlite.cursor() as cursor:
+        cursor.execute("SELECT @a + 1, @b", {"@b": 2, "@a": 1})
+        df = cursor.fetch_polars()
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "@a + 1": [2],
+                    "@b": [2],
+                }
+            ),
+        )
+
+        # Ensure the state of the cursor isn't affected
+        cursor.execute("SELECT ?2 + 1, ?1", [2, 1])
+        df = cursor.fetch_polars()
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "?2 + 1": [2],
+                    "?1": [2],
+                }
+            ),
+        )
+
+        cursor.execute("SELECT @a + 1, @b + @b", {"@b": 2, "@a": 1})
+        df = cursor.fetch_polars()
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "@a + 1": [2],
+                    "@b + @b": [4],
+                }
+            ),
+        )
+
+        data = polars.DataFrame({"float": [1.0], "int": [2]})
+        cursor.adbc_ingest("ingest_tester", data)
+        cursor.execute("SELECT * FROM ingest_tester")
+        df = cursor.fetch_polars()
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "float": [1.0],
+                    "int": [2],
+                }
+            ),
+        )
+
+
+def test_executemany_parameters_name(sqlite):
+    with sqlite.cursor() as cursor:
+        cursor.execute("CREATE TABLE executemany_params (a, b)")
+
+        cursor.executemany(
+            "INSERT INTO executemany_params VALUES (@a, @b)",
+            [{"@b": 2, "@a": 1}, {"@b": 3, "@a": 2}],
+        )
+        cursor.executemany(
+            "INSERT INTO executemany_params VALUES (?, ?)", [(3, 4), (4, 5)]
+        )
+
+        cursor.execute("SELECT * FROM executemany_params ORDER BY a ASC")
+        df = cursor.fetch_polars()
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "a": [1, 2, 3, 4],
+                    "b": [2, 3, 4, 5],
+                }
+            ),
+        )
+
+
 @pytest.mark.parametrize(
     "parameters",
     [

Reply via email to