This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new ab60af3 fix(python/adbc_driver_manager): properly map error codes (#510)
ab60af3 is described below
commit ab60af3e56a9b2b11a02b6bf837ec53a22ecfd44
Author: David Li <[email protected]>
AuthorDate: Mon Mar 13 12:41:06 2023 -0400
fix(python/adbc_driver_manager): properly map error codes (#510)
Fixes #507.
---
go/adbc/driver/flightsql/flightsql_adbc_test.go | 8 ++--
.../adbc_driver_manager/_lib.pyx | 43 +++++++++++++++++++---
python/adbc_driver_manager/tests/test_lowlevel.py | 38 +++++++++++++++++++
3 files changed, 79 insertions(+), 10 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 664644e..0ff02ef 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -865,7 +865,7 @@ func (ts *TimeoutTestSuite) TestDoActionTimeout() {
ts.Require().NoError(stmt.SetSqlQuery("fetch"))
var adbcErr adbc.Error
ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
- ts.Equal(adbc.StatusTimeout, adbcErr.Code)
+ ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
}
func (ts *TimeoutTestSuite) TestDoGetTimeout() {
@@ -880,7 +880,7 @@ func (ts *TimeoutTestSuite) TestDoGetTimeout() {
var adbcErr adbc.Error
_, _, err = stmt.ExecuteQuery(context.Background())
ts.ErrorAs(err, &adbcErr)
- ts.Equal(adbc.StatusTimeout, adbcErr.Code)
+ ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
}
func (ts *TimeoutTestSuite) TestDoPutTimeout() {
@@ -895,7 +895,7 @@ func (ts *TimeoutTestSuite) TestDoPutTimeout() {
var adbcErr adbc.Error
_, err = stmt.ExecuteUpdate(context.Background())
ts.ErrorAs(err, &adbcErr)
- ts.Equal(adbc.StatusTimeout, adbcErr.Code)
+ ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
}
func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
@@ -910,7 +910,7 @@ func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
var adbcErr adbc.Error
_, _, err = stmt.ExecuteQuery(context.Background())
ts.ErrorAs(err, &adbcErr)
- ts.NotEqual(adbc.StatusNotImplemented, adbcErr.Code)
+ ts.NotEqual(adbc.StatusNotImplemented, adbcErr.Code, adbcErr.Error())
}
func (ts *TimeoutTestSuite) TestDontTimeout() {
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 85cc258..fac37e8 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -283,8 +283,8 @@ class Error(Exception):
def __init__(self, message, *, status_code, vendor_code=None,
sqlstate=None):
super().__init__(message)
self.status_code = AdbcStatusCode(status_code)
- self.vendor_code = None
- self.sqlstate = None
+ self.vendor_code = vendor_code
+ self.sqlstate = sqlstate
class InterfaceError(Error):
@@ -347,15 +347,23 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
message += error.message.decode("utf-8")
if error.vendor_code:
vendor_code = error.vendor_code
+ message += f". Vendor code: {vendor_code}"
if error.sqlstate[0] != 0:
- sqlstate = error.sqlstate.decode("ascii")
+ sqlstate = bytes(error.sqlstate[i] for i in range(5))
+ sqlstate = sqlstate.decode("ascii")
+ message += f". SQLSTATE: {sqlstate}"
if error.release:
error.release(error)
klass = Error
if status in (ADBC_STATUS_INVALID_DATA,):
klass = DataError
- elif status in (ADBC_STATUS_IO, ADBC_STATUS_CANCELLED,
ADBC_STATUS_TIMEOUT):
+ elif status in (
+ ADBC_STATUS_IO,
+ ADBC_STATUS_CANCELLED,
+ ADBC_STATUS_TIMEOUT,
+ ADBC_STATUS_UNKNOWN,
+ ):
klass = OperationalError
elif status in (ADBC_STATUS_INTEGRITY,):
klass = IntegrityError
@@ -364,12 +372,13 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
elif status in (ADBC_STATUS_ALREADY_EXISTS,
ADBC_STATUS_INVALID_ARGUMENT,
ADBC_STATUS_INVALID_STATE,
+ ADBC_STATUS_NOT_FOUND,
ADBC_STATUS_UNAUTHENTICATED,
ADBC_STATUS_UNAUTHORIZED):
klass = ProgrammingError
elif status == ADBC_STATUS_NOT_IMPLEMENTED:
- raise NotSupportedError(message)
- raise klass(message, status_code=status)
+ raise NotSupportedError(message, vendor_code=vendor_code,
sqlstate=sqlstate)
+ raise klass(message, status_code=status, vendor_code=vendor_code,
sqlstate=sqlstate)
cdef CAdbcError empty_error():
@@ -386,6 +395,28 @@ cdef bytes _to_bytes(obj, str name):
raise ValueError(f"{name} must be str or bytes")
+def _test_error(status_code, message, vendor_code, sqlstate) -> Error:
+ cdef CAdbcError error
+ error.release = NULL
+
+ message = _to_bytes(message, "message")
+ error.message = message
+
+ if vendor_code:
+ error.vendor_code = vendor_code
+ else:
+ error.vendor_code = 0
+
+ if sqlstate:
+ sqlstate = sqlstate.encode("ascii")
+ else:
+ sqlstate = b"\0\0\0\0\0"
+ for i in range(5):
+ error.sqlstate[i] = sqlstate[i]
+
+ return check_error(AdbcStatusCode(status_code), &error)
+
+
cdef class _AdbcHandle:
"""
Base class for ADBC handles, which are context managers.
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 8013e0d..08209fb 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -58,6 +58,44 @@ def test_database_init():
pass
+def test_error_mapping():
+ import adbc_driver_manager._lib as _lib
+ from adbc_driver_manager import AdbcStatusCode
+
+ cases = [
+ (adbc_driver_manager.OperationalError, AdbcStatusCode.UNKNOWN),
+ (adbc_driver_manager.NotSupportedError,
AdbcStatusCode.NOT_IMPLEMENTED),
+ (adbc_driver_manager.ProgrammingError, AdbcStatusCode.NOT_FOUND),
+ (adbc_driver_manager.ProgrammingError, AdbcStatusCode.ALREADY_EXISTS),
+ (adbc_driver_manager.ProgrammingError,
AdbcStatusCode.INVALID_ARGUMENT),
+ (adbc_driver_manager.ProgrammingError, AdbcStatusCode.INVALID_STATE),
+ (adbc_driver_manager.DataError, AdbcStatusCode.INVALID_DATA),
+ (adbc_driver_manager.IntegrityError, AdbcStatusCode.INTEGRITY),
+ (adbc_driver_manager.InternalError, AdbcStatusCode.INTERNAL),
+ (adbc_driver_manager.OperationalError, AdbcStatusCode.IO),
+ (adbc_driver_manager.OperationalError, AdbcStatusCode.CANCELLED),
+ (adbc_driver_manager.OperationalError, AdbcStatusCode.TIMEOUT),
+ (adbc_driver_manager.ProgrammingError, AdbcStatusCode.UNAUTHENTICATED),
+ (adbc_driver_manager.ProgrammingError, AdbcStatusCode.UNAUTHORIZED),
+ ]
+
+ message = "Message"
+ for (klass, code) in cases:
+ with pytest.raises(klass) as exc_info:
+ _lib._test_error(code, message, vendor_code=None, sqlstate=None)
+ assert message in exc_info.value.args[0]
+ assert exc_info.value.status_code == code
+ assert exc_info.value.vendor_code is None
+ assert exc_info.value.sqlstate is None
+
+ with pytest.raises(klass) as exc_info:
+ _lib._test_error(code, message, vendor_code=42, sqlstate="X0000")
+ assert message in exc_info.value.args[0]
+ assert exc_info.value.status_code == code
+ assert exc_info.value.vendor_code == 42
+ assert exc_info.value.sqlstate == "X0000"
+
+
@pytest.mark.sqlite
def test_database_set_options(sqlite):
db, _ = sqlite