This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 2a819ea4 test(python/adbc_driver_flightsql): test Flight SQL & error details (#1071)
2a819ea4 is described below
commit 2a819ea40a65b8b6f9b7cc887b33f2fe14f4b5a0
Author: David Li <[email protected]>
AuthorDate: Fri Sep 15 14:21:51 2023 -0400
test(python/adbc_driver_flightsql): test Flight SQL & error details (#1071)
Fixes #1062.
---
ci/conda_env_python.txt | 1 +
ci/scripts/python_sdist_test.sh | 2 +-
ci/scripts/python_wheel_unix_test.sh | 2 +-
ci/scripts/python_wheel_windows_test.bat | 2 +-
go/adbc/driver/flightsql/cmd/testserver/main.go | 95 ++++++++++++-
.../driver/flightsql/flightsql_adbc_server_test.go | 25 ++--
go/adbc/driver/flightsql/utils.go | 28 ++--
python/adbc_driver_flightsql/pyproject.toml | 2 +-
python/adbc_driver_flightsql/tests/test_dbapi.py | 15 --
python/adbc_driver_flightsql/tests/test_errors.py | 151 +++++++++++++++++++++
.../adbc_driver_manager/_lib.pxd | 1 +
.../adbc_driver_manager/_lib.pyx | 17 ++-
.../adbc_driver_manager/_reader.pyx | 10 +-
13 files changed, 301 insertions(+), 50 deletions(-)
diff --git a/ci/conda_env_python.txt b/ci/conda_env_python.txt
index 63d30fc9..5ebb75ee 100644
--- a/ci/conda_env_python.txt
+++ b/ci/conda_env_python.txt
@@ -24,4 +24,5 @@ setuptools
# For integration testing
polars
+protobuf
python-duckdb
diff --git a/ci/scripts/python_sdist_test.sh b/ci/scripts/python_sdist_test.sh
index baaacf74..2739476e 100755
--- a/ci/scripts/python_sdist_test.sh
+++ b/ci/scripts/python_sdist_test.sh
@@ -47,7 +47,7 @@ echo "=== Installing sdists ==="
for component in ${COMPONENTS}; do
pip install --no-deps --force-reinstall
${source_dir}/python/${component}/dist/*.tar.gz
done
-pip install pytest pyarrow pandas
+pip install pytest pyarrow pandas protobuf
echo "=== (${PYTHON_VERSION}) Testing sdists ==="
test_packages
diff --git a/ci/scripts/python_wheel_unix_test.sh b/ci/scripts/python_wheel_unix_test.sh
index 73e0e741..87943cee 100755
--- a/ci/scripts/python_wheel_unix_test.sh
+++ b/ci/scripts/python_wheel_unix_test.sh
@@ -49,7 +49,7 @@ for component in ${COMPONENTS}; do
echo "NOTE: assuming wheels are already installed"
fi
done
-pip install pytest pyarrow pandas
+pip install pytest pyarrow pandas protobuf
echo "=== (${PYTHON_VERSION}) Testing wheels ==="
diff --git a/ci/scripts/python_wheel_windows_test.bat b/ci/scripts/python_wheel_windows_test.bat
index f598fbbd..019ce6cf 100644
--- a/ci/scripts/python_wheel_windows_test.bat
+++ b/ci/scripts/python_wheel_windows_test.bat
@@ -27,7 +27,7 @@ FOR %%c IN (adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adb
)
)
-pip install pytest pyarrow pandas
+pip install pytest pyarrow pandas protobuf
echo "=== (%PYTHON_VERSION%) Testing wheels ==="
diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go
index 6e0ca4ff..947fba75 100644
--- a/go/adbc/driver/flightsql/cmd/testserver/main.go
+++ b/go/adbc/driver/flightsql/cmd/testserver/main.go
@@ -22,7 +22,6 @@
package main
import (
- "bytes"
"context"
"flag"
"fmt"
@@ -39,23 +38,50 @@ import (
"github.com/apache/arrow/go/v13/arrow/memory"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/types/known/anypb"
+ "google.golang.org/protobuf/types/known/wrapperspb"
)
type ExampleServer struct {
flightsql.BaseServer
}
+func StatusWithDetail(code codes.Code, message string, details
...proto.Message) error {
+ p := status.New(code, message).Proto()
+ // Have to do this by hand because gRPC uses deprecated proto import
+ for _, detail := range details {
+ any, err := anypb.New(detail)
+ if err != nil {
+ panic(err)
+ }
+ p.Details = append(p.Details, any)
+ }
+ return status.FromProto(p).Err()
+}
+
func (srv *ExampleServer) ClosePreparedStatement(ctx context.Context, request
flightsql.ActionClosePreparedStatementRequest) error {
return nil
}
func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req
flightsql.ActionCreatePreparedStatementRequest) (result
flightsql.ActionCreatePreparedStatementResult, err error) {
+ switch req.GetQuery() {
+ case "error_create_prepared_statement":
+ err = status.Error(codes.InvalidArgument, "expected error
(DoAction)")
+ return
+ case "error_create_prepared_statement_detail":
+ detail1 := wrapperspb.String("detail1")
+ detail2 := wrapperspb.String("detail2")
+ err = StatusWithDetail(codes.InvalidArgument, "expected error
(DoAction)", detail1, detail2)
+ return
+ }
result.Handle = []byte(req.GetQuery())
return
}
func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context,
cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor)
(*flight.FlightInfo, error) {
- if bytes.Equal(cmd.GetPreparedStatementHandle(),
[]byte("error_do_get")) || bytes.Equal(cmd.GetPreparedStatementHandle(),
[]byte("error_do_get_stream")) {
+ switch string(cmd.GetPreparedStatementHandle()) {
+ case "error_do_get", "error_do_get_stream", "error_do_get_detail",
"error_do_get_stream_detail", "forever":
schema := arrow.NewSchema([]arrow.Field{{Name: "ints", Type:
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
return &flight.FlightInfo{
Endpoint: []*flight.FlightEndpoint{{Ticket:
&flight.Ticket{Ticket: desc.Cmd}}},
@@ -64,6 +90,12 @@ func (srv *ExampleServer) GetFlightInfoPreparedStatement(_
context.Context, cmd
TotalBytes: -1,
Schema: flight.SerializeSchema(schema,
srv.Alloc),
}, nil
+ case "error_get_flight_info":
+ return nil, status.Error(codes.InvalidArgument, "expected error
(GetFlightInfo)")
+ case "error_get_flight_info_detail":
+ detail1 := wrapperspb.String("detail1")
+ detail2 := wrapperspb.String("detail2")
+ return nil, StatusWithDetail(codes.InvalidArgument, "expected
error (GetFlightInfo)", detail1, detail2)
}
return &flight.FlightInfo{
@@ -90,8 +122,33 @@ func (srv *ExampleServer) GetFlightInfoStatement(ctx
context.Context, cmd flight
func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd
flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan
flight.StreamChunk, err error) {
log.Printf("DoGetPreparedStatement: %v",
cmd.GetPreparedStatementHandle())
- if bytes.Equal(cmd.GetPreparedStatementHandle(),
[]byte("error_do_get")) {
- err = status.Error(codes.InvalidArgument, "expected error")
+ switch string(cmd.GetPreparedStatementHandle()) {
+ case "error_do_get":
+ err = status.Error(codes.InvalidArgument, "expected error
(DoGet)")
+ return
+ case "error_do_get_detail":
+ detail1 := wrapperspb.String("detail1")
+ detail2 := wrapperspb.String("detail2")
+ err = StatusWithDetail(codes.InvalidArgument, "expected error
(DoGet)", detail1, detail2)
+ return
+ case "forever":
+ ch := make(chan flight.StreamChunk)
+ schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type:
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+ var rec arrow.Record
+ rec, _, err = array.RecordFromJSON(memory.DefaultAllocator,
schema, strings.NewReader(`[{"a": 5}]`))
+ go func() {
+ // wait for client cancel
+ <-ctx.Done()
+ defer close(ch)
+
// arrow-go crashes if the stream closes without sending a chunk, so send one
+ ch <- flight.StreamChunk{
+ Data: rec,
+ Desc: nil,
+ Err: nil,
+ }
+ }()
+ out = ch
return
}
@@ -106,11 +163,20 @@ func (srv *ExampleServer) DoGetPreparedStatement(ctx
context.Context, cmd flight
Desc: nil,
Err: nil,
}
- if bytes.Equal(cmd.GetPreparedStatementHandle(),
[]byte("error_do_get_stream")) {
+ switch string(cmd.GetPreparedStatementHandle()) {
+ case "error_do_get_stream":
+ ch <- flight.StreamChunk{
+ Data: nil,
+ Desc: nil,
+ Err: status.Error(codes.InvalidArgument,
"expected stream error (DoGet)"),
+ }
+ case "error_do_get_stream_detail":
+ detail1 := wrapperspb.String("detail1")
+ detail2 := wrapperspb.String("detail2")
ch <- flight.StreamChunk{
Data: nil,
Desc: nil,
- Err: status.Error(codes.InvalidArgument,
"expected error"),
+ Err: StatusWithDetail(codes.InvalidArgument,
"expected stream error (DoGet)", detail1, detail2),
}
}
}()
@@ -135,6 +201,23 @@ func (srv *ExampleServer) DoGetStatement(ctx
context.Context, cmd flightsql.Stat
return
}
+func (srv *ExampleServer) DoPutPreparedStatementQuery(ctx context.Context, cmd
flightsql.PreparedStatementQuery, reader flight.MessageReader, writer
flight.MetadataWriter) error {
+ switch string(cmd.GetPreparedStatementHandle()) {
+ case "error_do_put":
+ return status.Error(codes.Unknown, "expected error (DoPut)")
+ case "error_do_put_detail":
+ detail1 := wrapperspb.String("detail1")
+ detail2 := wrapperspb.String("detail2")
+ return StatusWithDetail(codes.Unknown, "expected error
(DoPut)", detail1, detail2)
+ }
+
+ return status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery
not implemented")
+}
+
+func (srv *ExampleServer) DoPutPreparedStatementUpdate(context.Context,
flightsql.PreparedStatementUpdate, flight.MessageReader) (int64, error) {
+ return 0, status.Error(codes.Unimplemented,
"DoPutPreparedStatementUpdate not implemented")
+}
+
func main() {
var (
host = flag.String("host", "localhost", "hostname to bind to")
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index d43b9fd6..e6adaae1 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -37,12 +37,15 @@ import (
"github.com/apache/arrow/go/v13/arrow/flight/flightsql"
"github.com/apache/arrow/go/v13/arrow/flight/flightsql/schema_ref"
"github.com/apache/arrow/go/v13/arrow/memory"
+ "github.com/golang/protobuf/ptypes/wrappers"
"github.com/stretchr/testify/suite"
"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
@@ -287,12 +290,15 @@ func (ts *ErrorDetailsTests) TestGetFlightInfo() {
ts.Equal(1, len(adbcErr.Details))
- wrapper, ok := adbcErr.Details[0].(*adbc.ProtobufErrorDetail)
- ts.True(ok, "Got message: %#v", wrapper)
+ wrapper := adbcErr.Details[0]
ts.Equal("grpc-status-details-bin", wrapper.Key())
- message, ok := wrapper.Message.(*wrapperspb.Int32Value)
- ts.True(ok, "Got message: %#v", message)
+ raw, err := wrapper.Serialize()
+ ts.NoError(err)
+ any := anypb.Any{}
+ ts.NoError(proto.Unmarshal(raw, &any))
+ message := wrappers.Int32Value{}
+ ts.NoError(any.UnmarshalTo(&message))
ts.Equal(int32(42), message.Value)
}
@@ -319,12 +325,15 @@ func (ts *ErrorDetailsTests) TestDoGet() {
ts.Equal(1, len(adbcErr.Details))
- wrapper, ok := adbcErr.Details[0].(*adbc.ProtobufErrorDetail)
- ts.True(ok, "Got message: %#v", wrapper)
+ wrapper := adbcErr.Details[0]
ts.Equal("grpc-status-details-bin", wrapper.Key())
- message, ok := wrapper.Message.(*wrapperspb.Int32Value)
- ts.True(ok, "Got message: %#v", message)
+ raw, err := wrapper.Serialize()
+ ts.NoError(err)
+ any := anypb.Any{}
+ ts.NoError(proto.Unmarshal(raw, &any))
+ message := wrappers.Int32Value{}
+ ts.NoError(any.UnmarshalTo(&message))
ts.Equal(int32(42), message.Value)
}
diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go
index d0d1af85..fef8e738 100644
--- a/go/adbc/driver/flightsql/utils.go
+++ b/go/adbc/driver/flightsql/utils.go
@@ -26,6 +26,7 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/types/known/anypb"
)
func adbcFromFlightStatus(err error, context string, args ...any) error {
@@ -81,14 +82,8 @@ func adbcFromFlightStatusWithDetails(err error, header,
trailer metadata.MD, con
}
details := []adbc.ErrorDetail{}
- // slice of proto.Message or error
- for _, detail := range grpcStatus.Details() {
- if err, ok := detail.(error); ok {
- details = append(details, &adbc.TextErrorDetail{Name:
"grpc-status-details-bin", Detail: err.Error()})
- } else if msg, ok := detail.(proto.Message); ok {
- details = append(details,
&adbc.ProtobufErrorDetail{Name: "grpc-status-details-bin", Message: msg})
- }
- // else, gRPC returned non-Protobuf detail in violation of
their method contract
+ for _, detail := range grpcStatus.Proto().Details {
+ details = append(details, &anyErrorDetail{name:
"grpc-status-details-bin", message: detail})
}
// XXX(https://github.com/grpc/grpc-go/issues/5485): don't count on
@@ -135,3 +130,20 @@ func checkContext(maybeErr error, ctx context.Context)
error {
}
return ctx.Err()
}
+
+// gRPC's Status.Details() can return an error in place of each item when it
+// deserializes the error details. Instead, take the raw details from
+// Status.Proto() and return each protobuf Any to the client unchanged.
+type anyErrorDetail struct {
+ name string
+ message *anypb.Any
+}
+
+func (d *anyErrorDetail) Key() string {
+ return d.name
+}
+
+// Serialize serializes the Protobuf message (wrapped in Any).
+func (d *anyErrorDetail) Serialize() ([]byte, error) {
+ return proto.Marshal(d.message)
+}
diff --git a/python/adbc_driver_flightsql/pyproject.toml b/python/adbc_driver_flightsql/pyproject.toml
index 58f73c77..ba3485f9 100644
--- a/python/adbc_driver_flightsql/pyproject.toml
+++ b/python/adbc_driver_flightsql/pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
[project.optional-dependencies]
dbapi = ["pandas", "pyarrow>=8.0.0"]
-test = ["pandas", "pyarrow>=8.0.0", "pytest"]
+test = ["pandas", "protobuf", "pyarrow>=8.0.0", "pytest"]
[project.urls]
homepage = "https://arrow.apache.org/adbc/"
diff --git a/python/adbc_driver_flightsql/tests/test_dbapi.py b/python/adbc_driver_flightsql/tests/test_dbapi.py
index e1990354..0918fc7a 100644
--- a/python/adbc_driver_flightsql/tests/test_dbapi.py
+++ b/python/adbc_driver_flightsql/tests/test_dbapi.py
@@ -33,21 +33,6 @@ def test_query_error(dremio_dbapi):
assert exc.args[0].startswith("INVALID_ARGUMENT: [FlightSQL] ")
-def test_query_error_fetch(test_dbapi):
- with test_dbapi.cursor() as cur:
- cur.execute("error_do_get")
- with pytest.raises(Exception, match="expected error"):
- cur.fetch_arrow_table()
-
-
-def test_query_error_stream(test_dbapi):
- with test_dbapi.cursor() as cur:
- cur.execute("error_do_get_stream")
- with pytest.raises(Exception, match="expected error"):
- cur.fetchone()
- cur.fetchone()
-
-
def test_query_trivial(dremio_dbapi):
with dremio_dbapi.cursor() as cur:
cur.execute("SELECT 1")
diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py
new file mode 100644
index 00000000..ed44b6a3
--- /dev/null
+++ b/python/adbc_driver_flightsql/tests/test_errors.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import re
+
+import google.protobuf.any_pb2 as any_pb2
+import google.protobuf.wrappers_pb2 as wrappers_pb2
+import pytest
+
+
+def assert_detail(e):
+ # Check that the expected error details are present
+ found = set()
+ for _, detail in e.details:
+ anyproto = any_pb2.Any()
+ anyproto.ParseFromString(detail)
+ string = wrappers_pb2.StringValue()
+ anyproto.Unpack(string)
+ found.add(string.value)
+ assert found == {"detail1", "detail2"}
+
+
+def test_query_cancel(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.execute("forever")
+ cur.adbc_cancel()
+ with pytest.raises(
+ test_dbapi.OperationalError,
+ match=re.escape("CANCELLED: [FlightSQL] context canceled"),
+ ):
+ cur.fetchone()
+
+
+def test_query_error_fetch(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.execute("error_do_get")
+ # Match more exactly to make sure there's not unexpected junk in the
string
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error
(DoGet)"),
+ ):
+ cur.fetch_arrow_table()
+
+ cur.execute("error_do_get_detail")
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error
(DoGet)"),
+ ) as excval:
+ cur.fetch_arrow_table()
+ assert_detail(excval.value)
+
+
+def test_query_error_stream(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.execute("error_do_get_stream")
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape(
+ "INVALID_ARGUMENT: [FlightSQL] expected stream error (DoGet)"
+ ),
+ ):
+ cur.fetchone()
+ cur.fetchone()
+
+ cur.execute("error_do_get_stream_detail")
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape(
+ "INVALID_ARGUMENT: [FlightSQL] expected stream error (DoGet)"
+ ),
+ ) as excval:
+ cur.fetchone()
+ cur.fetchone()
+ assert_detail(excval.value)
+
+
+def test_query_error_bind(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.adbc_prepare("error_do_put")
+ with pytest.raises(
+ test_dbapi.OperationalError,
+ match=re.escape("UNKNOWN: [FlightSQL] expected error (DoPut)"),
+ ):
+ cur.execute("error_do_put", parameters=(1, "a"))
+
+ cur.adbc_prepare("error_do_put_detail")
+ with pytest.raises(
+ test_dbapi.OperationalError,
+ match=re.escape("UNKNOWN: [FlightSQL] expected error (DoPut)"),
+ ) as excval:
+ cur.execute("error_do_put_detail", parameters=(1, "a"))
+ assert_detail(excval.value)
+
+
+def test_query_error_create_prepared_statement(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error
(DoAction)"),
+ ):
+ cur.adbc_prepare("error_create_prepared_statement")
+
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error
(DoAction)"),
+ ) as excval:
+ cur.adbc_prepare("error_create_prepared_statement_detail")
+ assert_detail(excval.value)
+
+
+def test_query_error_getflightinfo(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ with pytest.raises(
+ Exception,
+ match=re.escape(
+ "INVALID_ARGUMENT: [FlightSQL] expected error (GetFlightInfo)"
+ ),
+ ):
+ cur.execute("error_get_flight_info")
+
+ with pytest.raises(
+ Exception,
+ match=re.escape(
+ "INVALID_ARGUMENT: [FlightSQL] expected error (GetFlightInfo)"
+ ),
+ ) as excval:
+ cur.execute("error_get_flight_info_detail")
+ assert_detail(excval.value)
+
+ cur.adbc_prepare("error_get_flight_info")
+ with pytest.raises(
+ Exception,
+ match=re.escape(
+ "INVALID_ARGUMENT: [FlightSQL] expected error (GetFlightInfo)"
+ ),
+ ):
+ cur.adbc_execute_partitions("error_get_flight_info")
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
index 88a61a66..358a09aa 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -282,6 +282,7 @@ cdef const CAdbcError* PyAdbcErrorFromArrayStream(
CArrowArrayStream* stream, CAdbcStatusCode* status)
cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *
+cdef object convert_error(CAdbcStatusCode status, CAdbcError* error) except *
cdef extern from "adbc_driver_manager.h":
const char* CAdbcStatusCodeMessage"AdbcStatusCodeMessage"(CAdbcStatusCode
code)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index e21130a7..ced8870e 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -158,15 +158,16 @@ INGEST_OPTION_MODE_CREATE_APPEND =
ADBC_INGEST_OPTION_MODE_CREATE_APPEND.decode(
INGEST_OPTION_TARGET_TABLE = ADBC_INGEST_OPTION_TARGET_TABLE.decode("utf-8")
-cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
+cdef object convert_error(CAdbcStatusCode status, CAdbcError* error) except *:
cdef CAdbcErrorDetail c_detail
if status == ADBC_STATUS_OK:
- return
+ return None
message = CAdbcStatusCodeMessage(status).decode("utf-8")
vendor_code = None
sqlstate = None
+ details = []
if error != NULL:
if error.message != NULL:
@@ -181,7 +182,6 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError*
error) except *:
message += f". SQLSTATE: {sqlstate}"
num_details = AdbcErrorGetDetailCount(error)
- details = []
for index in range(num_details):
c_detail = AdbcErrorGetDetail(error, index)
if c_detail.key == NULL or c_detail.value == NULL:
@@ -216,8 +216,15 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError*
error) except *:
ADBC_STATUS_UNAUTHORIZED):
klass = ProgrammingError
elif status == ADBC_STATUS_NOT_IMPLEMENTED:
- raise NotSupportedError(message, vendor_code=vendor_code,
sqlstate=sqlstate, details=details)
- raise klass(message, status_code=status, vendor_code=vendor_code,
sqlstate=sqlstate, details=details)
+ return NotSupportedError(message, vendor_code=vendor_code,
sqlstate=sqlstate, details=details)
+ return klass(message, status_code=status, vendor_code=vendor_code,
sqlstate=sqlstate, details=details)
+
+
+cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
+ if status == ADBC_STATUS_OK:
+ return
+
+ raise convert_error(status, error)
cdef CAdbcError empty_error():
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx
index 43de6740..e7dcc16f 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx
@@ -30,12 +30,14 @@ cdef class _AdbcErrorHelper:
def check_error(self, exception):
cdef:
- CAdbcStatusCode c_status
+ CAdbcStatusCode c_status = ADBC_STATUS_OK
const CAdbcError* error =
PyAdbcErrorFromArrayStream(&self.c_stream, &c_status)
- CAdbcErrorDetail detail
- if error != NULL:
- check_error(c_status, <CAdbcError*> error)
+ exc = convert_error(c_status, <CAdbcError*> error)
+ if exc is not None:
+ # Suppress "During handling of the above exception, another
+ # exception occurred"
+ raise exc from None
raise exception