This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 19a30aea feat(python/adbc_driver_manager): add fetch_record_batch (#989)
19a30aea is described below

commit 19a30aea784dfdce1fed506604f0272a222a3294
Author: Solomon Choe <[email protected]>
AuthorDate: Wed Aug 23 11:25:57 2023 -0700

    feat(python/adbc_driver_manager): add fetch_record_batch (#989)
    
    Fixes #968
    
    ---------
    
    Co-authored-by: David Li <[email protected]>
---
 .../adbc_driver_manager/dbapi.py                   | 26 ++++++++++++++++++----
 python/adbc_driver_manager/tests/test_dbapi.py     | 20 +++++++++++++++++
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 31e4392a..60bc2d1b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -926,6 +926,24 @@ class Cursor(_Closeable):
             )
         return self._results.fetch_df()
 
+    def fetch_record_batch(self) -> pyarrow.RecordBatchReader:
+        """
+        Fetch the result as a PyArrow RecordBatchReader.
+
+        This implements a similar API as DuckDB:
+        https://duckdb.org/docs/guides/python/export_arrow.html#export-as-a-recordbatchreader
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        if self._results is None:
+            raise ProgrammingError(
+                "Cannot fetch_record_batch() before execute()",
+                status_code=_lib.AdbcStatusCode.INVALID_STATE,
+            )
+        return self._results._reader
+
 
 # ----------------------------------------------------------
 # Utilities
@@ -973,7 +991,7 @@ class _RowIterator(_Closeable):
         self.rownumber += 1
         return row
 
-    def fetchmany(self, size: int):
+    def fetchmany(self, size: int) -> List[tuple]:
         rows = []
         for _ in range(size):
             row = self.fetchone()
@@ -982,7 +1000,7 @@ class _RowIterator(_Closeable):
             rows.append(row)
         return rows
 
-    def fetchall(self):
+    def fetchall(self) -> List[tuple]:
         rows = []
         while True:
             row = self.fetchone()
@@ -991,10 +1009,10 @@ class _RowIterator(_Closeable):
             rows.append(row)
         return rows
 
-    def fetch_arrow_table(self):
+    def fetch_arrow_table(self) -> pyarrow.Table:
         return self._reader.read_all()
 
-    def fetch_df(self):
+    def fetch_df(self) -> "pandas.DataFrame":
         return self._reader.read_pandas()
 
 
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 1eba12fd..a29a661a 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -294,6 +294,26 @@ def test_executemany(sqlite):
         assert next(cur) == (5, 6)
 
 
+@pytest.mark.sqlite
+def test_fetch_record_batch(sqlite):
+    dataset = [
+        [1, 2],
+        [3, 4],
+        [5, 6],
+        [7, 8],
+        [9, 10],
+    ]
+    with sqlite.cursor() as cur:
+        cur.execute("CREATE TABLE foo (a, b)")
+        cur.executemany(
+            "INSERT INTO foo VALUES (?, ?)",
+            dataset,
+        )
+        cur.execute("SELECT * FROM foo")
+        rbr = cur.fetch_record_batch()
+        assert rbr.read_pandas().values.tolist() == dataset
+
+
 @pytest.mark.sqlite
 def test_fetch_empty(sqlite):
     with sqlite.cursor() as cur:

Reply via email to