This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 19a30aea feat(python/adbc_driver_manager): add fetch_record_batch (#989)
19a30aea is described below
commit 19a30aea784dfdce1fed506604f0272a222a3294
Author: Solomon Choe <[email protected]>
AuthorDate: Wed Aug 23 11:25:57 2023 -0700
feat(python/adbc_driver_manager): add fetch_record_batch (#989)
Fixes #968
---------
Co-authored-by: David Li <[email protected]>
---
.../adbc_driver_manager/dbapi.py | 26 ++++++++++++++++++----
python/adbc_driver_manager/tests/test_dbapi.py | 20 +++++++++++++++++
2 files changed, 42 insertions(+), 4 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 31e4392a..60bc2d1b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -926,6 +926,24 @@ class Cursor(_Closeable):
)
return self._results.fetch_df()
+ def fetch_record_batch(self) -> pyarrow.RecordBatchReader:
+ """
+ Fetch the result as a PyArrow RecordBatchReader.
+
+ This implements a similar API as DuckDB:
+ https://duckdb.org/docs/guides/python/export_arrow.html#export-as-a-recordbatchreader
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ if self._results is None:
+ raise ProgrammingError(
+ "Cannot fetch_record_batch() before execute()",
+ status_code=_lib.AdbcStatusCode.INVALID_STATE,
+ )
+ return self._results._reader
+
# ----------------------------------------------------------
# Utilities
@@ -973,7 +991,7 @@ class _RowIterator(_Closeable):
self.rownumber += 1
return row
- def fetchmany(self, size: int):
+ def fetchmany(self, size: int) -> List[tuple]:
rows = []
for _ in range(size):
row = self.fetchone()
@@ -982,7 +1000,7 @@ class _RowIterator(_Closeable):
rows.append(row)
return rows
- def fetchall(self):
+ def fetchall(self) -> List[tuple]:
rows = []
while True:
row = self.fetchone()
@@ -991,10 +1009,10 @@ class _RowIterator(_Closeable):
rows.append(row)
return rows
- def fetch_arrow_table(self):
+ def fetch_arrow_table(self) -> pyarrow.Table:
return self._reader.read_all()
- def fetch_df(self):
+ def fetch_df(self) -> "pandas.DataFrame":
return self._reader.read_pandas()
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 1eba12fd..a29a661a 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -294,6 +294,26 @@ def test_executemany(sqlite):
assert next(cur) == (5, 6)
+@pytest.mark.sqlite
+def test_fetch_record_batch(sqlite):
+ dataset = [
+ [1, 2],
+ [3, 4],
+ [5, 6],
+ [7, 8],
+ [9, 10],
+ ]
+ with sqlite.cursor() as cur:
+ cur.execute("CREATE TABLE foo (a, b)")
+ cur.executemany(
+ "INSERT INTO foo VALUES (?, ?)",
+ dataset,
+ )
+ cur.execute("SELECT * FROM foo")
+ rbr = cur.fetch_record_batch()
+ assert rbr.read_pandas().values.tolist() == dataset
+
+
@pytest.mark.sqlite
def test_fetch_empty(sqlite):
with sqlite.cursor() as cur: