This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new ab60af3  fix(python/adbc_driver_manager): properly map error codes (#510)
ab60af3 is described below

commit ab60af3e56a9b2b11a02b6bf837ec53a22ecfd44
Author: David Li <[email protected]>
AuthorDate: Mon Mar 13 12:41:06 2023 -0400

    fix(python/adbc_driver_manager): properly map error codes (#510)
    
    Fixes #507.
---
 go/adbc/driver/flightsql/flightsql_adbc_test.go    |  8 ++--
 .../adbc_driver_manager/_lib.pyx                   | 43 +++++++++++++++++++---
 python/adbc_driver_manager/tests/test_lowlevel.py  | 38 +++++++++++++++++++
 3 files changed, 79 insertions(+), 10 deletions(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 664644e..0ff02ef 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -865,7 +865,7 @@ func (ts *TimeoutTestSuite) TestDoActionTimeout() {
        ts.Require().NoError(stmt.SetSqlQuery("fetch"))
        var adbcErr adbc.Error
        ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
-       ts.Equal(adbc.StatusTimeout, adbcErr.Code)
+       ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
 }
 
 func (ts *TimeoutTestSuite) TestDoGetTimeout() {
@@ -880,7 +880,7 @@ func (ts *TimeoutTestSuite) TestDoGetTimeout() {
        var adbcErr adbc.Error
        _, _, err = stmt.ExecuteQuery(context.Background())
        ts.ErrorAs(err, &adbcErr)
-       ts.Equal(adbc.StatusTimeout, adbcErr.Code)
+       ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
 }
 
 func (ts *TimeoutTestSuite) TestDoPutTimeout() {
@@ -895,7 +895,7 @@ func (ts *TimeoutTestSuite) TestDoPutTimeout() {
        var adbcErr adbc.Error
        _, err = stmt.ExecuteUpdate(context.Background())
        ts.ErrorAs(err, &adbcErr)
-       ts.Equal(adbc.StatusTimeout, adbcErr.Code)
+       ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
 }
 
 func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
@@ -910,7 +910,7 @@ func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
        var adbcErr adbc.Error
        _, _, err = stmt.ExecuteQuery(context.Background())
        ts.ErrorAs(err, &adbcErr)
-       ts.NotEqual(adbc.StatusNotImplemented, adbcErr.Code)
+       ts.NotEqual(adbc.StatusNotImplemented, adbcErr.Code, adbcErr.Error())
 }
 
 func (ts *TimeoutTestSuite) TestDontTimeout() {
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 85cc258..fac37e8 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -283,8 +283,8 @@ class Error(Exception):
     def __init__(self, message, *, status_code, vendor_code=None, 
sqlstate=None):
         super().__init__(message)
         self.status_code = AdbcStatusCode(status_code)
-        self.vendor_code = None
-        self.sqlstate = None
+        self.vendor_code = vendor_code
+        self.sqlstate = sqlstate
 
 
 class InterfaceError(Error):
@@ -347,15 +347,23 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
             message += error.message.decode("utf-8")
         if error.vendor_code:
             vendor_code = error.vendor_code
+            message += f". Vendor code: {vendor_code}"
         if error.sqlstate[0] != 0:
-            sqlstate = error.sqlstate.decode("ascii")
+            sqlstate = bytes(error.sqlstate[i] for i in range(5))
+            sqlstate = sqlstate.decode("ascii")
+            message += f". SQLSTATE: {sqlstate}"
         if error.release:
             error.release(error)
 
     klass = Error
     if status in (ADBC_STATUS_INVALID_DATA,):
         klass = DataError
-    elif status in (ADBC_STATUS_IO, ADBC_STATUS_CANCELLED, 
ADBC_STATUS_TIMEOUT):
+    elif status in (
+        ADBC_STATUS_IO,
+        ADBC_STATUS_CANCELLED,
+        ADBC_STATUS_TIMEOUT,
+        ADBC_STATUS_UNKNOWN,
+    ):
         klass = OperationalError
     elif status in (ADBC_STATUS_INTEGRITY,):
         klass = IntegrityError
@@ -364,12 +372,13 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
     elif status in (ADBC_STATUS_ALREADY_EXISTS,
                     ADBC_STATUS_INVALID_ARGUMENT,
                     ADBC_STATUS_INVALID_STATE,
+                    ADBC_STATUS_NOT_FOUND,
                     ADBC_STATUS_UNAUTHENTICATED,
                     ADBC_STATUS_UNAUTHORIZED):
         klass = ProgrammingError
     elif status == ADBC_STATUS_NOT_IMPLEMENTED:
-        raise NotSupportedError(message)
-    raise klass(message, status_code=status)
+        raise NotSupportedError(message, vendor_code=vendor_code, 
sqlstate=sqlstate)
+    raise klass(message, status_code=status, vendor_code=vendor_code, 
sqlstate=sqlstate)
 
 
 cdef CAdbcError empty_error():
@@ -386,6 +395,28 @@ cdef bytes _to_bytes(obj, str name):
     raise ValueError(f"{name} must be str or bytes")
 
 
+def _test_error(status_code, message, vendor_code, sqlstate) -> Error:
+    cdef CAdbcError error
+    error.release = NULL
+
+    message = _to_bytes(message, "message")
+    error.message = message
+
+    if vendor_code:
+        error.vendor_code = vendor_code
+    else:
+        error.vendor_code = 0
+
+    if sqlstate:
+        sqlstate = sqlstate.encode("ascii")
+    else:
+        sqlstate = b"\0\0\0\0\0"
+    for i in range(5):
+        error.sqlstate[i] = sqlstate[i]
+
+    return check_error(AdbcStatusCode(status_code), &error)
+
+
 cdef class _AdbcHandle:
     """
     Base class for ADBC handles, which are context managers.
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 8013e0d..08209fb 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -58,6 +58,44 @@ def test_database_init():
             pass
 
 
+def test_error_mapping():
+    import adbc_driver_manager._lib as _lib
+    from adbc_driver_manager import AdbcStatusCode
+
+    cases = [
+        (adbc_driver_manager.OperationalError, AdbcStatusCode.UNKNOWN),
+        (adbc_driver_manager.NotSupportedError, 
AdbcStatusCode.NOT_IMPLEMENTED),
+        (adbc_driver_manager.ProgrammingError, AdbcStatusCode.NOT_FOUND),
+        (adbc_driver_manager.ProgrammingError, AdbcStatusCode.ALREADY_EXISTS),
+        (adbc_driver_manager.ProgrammingError, 
AdbcStatusCode.INVALID_ARGUMENT),
+        (adbc_driver_manager.ProgrammingError, AdbcStatusCode.INVALID_STATE),
+        (adbc_driver_manager.DataError, AdbcStatusCode.INVALID_DATA),
+        (adbc_driver_manager.IntegrityError, AdbcStatusCode.INTEGRITY),
+        (adbc_driver_manager.InternalError, AdbcStatusCode.INTERNAL),
+        (adbc_driver_manager.OperationalError, AdbcStatusCode.IO),
+        (adbc_driver_manager.OperationalError, AdbcStatusCode.CANCELLED),
+        (adbc_driver_manager.OperationalError, AdbcStatusCode.TIMEOUT),
+        (adbc_driver_manager.ProgrammingError, AdbcStatusCode.UNAUTHENTICATED),
+        (adbc_driver_manager.ProgrammingError, AdbcStatusCode.UNAUTHORIZED),
+    ]
+
+    message = "Message"
+    for (klass, code) in cases:
+        with pytest.raises(klass) as exc_info:
+            _lib._test_error(code, message, vendor_code=None, sqlstate=None)
+        assert message in exc_info.value.args[0]
+        assert exc_info.value.status_code == code
+        assert exc_info.value.vendor_code is None
+        assert exc_info.value.sqlstate is None
+
+        with pytest.raises(klass) as exc_info:
+            _lib._test_error(code, message, vendor_code=42, sqlstate="X0000")
+        assert message in exc_info.value.args[0]
+        assert exc_info.value.status_code == code
+        assert exc_info.value.vendor_code == 42
+        assert exc_info.value.sqlstate == "X0000"
+
+
 @pytest.mark.sqlite
 def test_database_set_options(sqlite):
     db, _ = sqlite

Reply via email to