This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a819ea4 test(python/adbc_driver_flightsql): test Flight SQL & error details (#1071)
2a819ea4 is described below

commit 2a819ea40a65b8b6f9b7cc887b33f2fe14f4b5a0
Author: David Li <[email protected]>
AuthorDate: Fri Sep 15 14:21:51 2023 -0400

    test(python/adbc_driver_flightsql): test Flight SQL & error details (#1071)
    
    Fixes #1062.
---
 ci/conda_env_python.txt                            |   1 +
 ci/scripts/python_sdist_test.sh                    |   2 +-
 ci/scripts/python_wheel_unix_test.sh               |   2 +-
 ci/scripts/python_wheel_windows_test.bat           |   2 +-
 go/adbc/driver/flightsql/cmd/testserver/main.go    |  95 ++++++++++++-
 .../driver/flightsql/flightsql_adbc_server_test.go |  25 ++--
 go/adbc/driver/flightsql/utils.go                  |  28 ++--
 python/adbc_driver_flightsql/pyproject.toml        |   2 +-
 python/adbc_driver_flightsql/tests/test_dbapi.py   |  15 --
 python/adbc_driver_flightsql/tests/test_errors.py  | 151 +++++++++++++++++++++
 .../adbc_driver_manager/_lib.pxd                   |   1 +
 .../adbc_driver_manager/_lib.pyx                   |  17 ++-
 .../adbc_driver_manager/_reader.pyx                |  10 +-
 13 files changed, 301 insertions(+), 50 deletions(-)

diff --git a/ci/conda_env_python.txt b/ci/conda_env_python.txt
index 63d30fc9..5ebb75ee 100644
--- a/ci/conda_env_python.txt
+++ b/ci/conda_env_python.txt
@@ -24,4 +24,5 @@ setuptools
 
 # For integration testing
 polars
+protobuf
 python-duckdb
diff --git a/ci/scripts/python_sdist_test.sh b/ci/scripts/python_sdist_test.sh
index baaacf74..2739476e 100755
--- a/ci/scripts/python_sdist_test.sh
+++ b/ci/scripts/python_sdist_test.sh
@@ -47,7 +47,7 @@ echo "=== Installing sdists ==="
 for component in ${COMPONENTS}; do
     pip install --no-deps --force-reinstall 
${source_dir}/python/${component}/dist/*.tar.gz
 done
-pip install pytest pyarrow pandas
+pip install pytest pyarrow pandas protobuf
 
 echo "=== (${PYTHON_VERSION}) Testing sdists ==="
 test_packages
diff --git a/ci/scripts/python_wheel_unix_test.sh 
b/ci/scripts/python_wheel_unix_test.sh
index 73e0e741..87943cee 100755
--- a/ci/scripts/python_wheel_unix_test.sh
+++ b/ci/scripts/python_wheel_unix_test.sh
@@ -49,7 +49,7 @@ for component in ${COMPONENTS}; do
         echo "NOTE: assuming wheels are already installed"
     fi
 done
-pip install pytest pyarrow pandas
+pip install pytest pyarrow pandas protobuf
 
 
 echo "=== (${PYTHON_VERSION}) Testing wheels ==="
diff --git a/ci/scripts/python_wheel_windows_test.bat 
b/ci/scripts/python_wheel_windows_test.bat
index f598fbbd..019ce6cf 100644
--- a/ci/scripts/python_wheel_windows_test.bat
+++ b/ci/scripts/python_wheel_windows_test.bat
@@ -27,7 +27,7 @@ FOR %%c IN (adbc_driver_manager adbc_driver_flightsql 
adbc_driver_postgresql adb
     )
 )
 
-pip install pytest pyarrow pandas
+pip install pytest pyarrow pandas protobuf
 
 echo "=== (%PYTHON_VERSION%) Testing wheels ==="
 
diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go 
b/go/adbc/driver/flightsql/cmd/testserver/main.go
index 6e0ca4ff..947fba75 100644
--- a/go/adbc/driver/flightsql/cmd/testserver/main.go
+++ b/go/adbc/driver/flightsql/cmd/testserver/main.go
@@ -22,7 +22,6 @@
 package main
 
 import (
-       "bytes"
        "context"
        "flag"
        "fmt"
@@ -39,23 +38,50 @@ import (
        "github.com/apache/arrow/go/v13/arrow/memory"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/status"
+       "google.golang.org/protobuf/proto"
+       "google.golang.org/protobuf/types/known/anypb"
+       "google.golang.org/protobuf/types/known/wrapperspb"
 )
 
 type ExampleServer struct {
        flightsql.BaseServer
 }
 
+func StatusWithDetail(code codes.Code, message string, details 
...proto.Message) error {
+       p := status.New(code, message).Proto()
+       // Have to do this by hand because gRPC uses deprecated proto import
+       for _, detail := range details {
+               any, err := anypb.New(detail)
+               if err != nil {
+                       panic(err)
+               }
+               p.Details = append(p.Details, any)
+       }
+       return status.FromProto(p).Err()
+}
+
 func (srv *ExampleServer) ClosePreparedStatement(ctx context.Context, request 
flightsql.ActionClosePreparedStatementRequest) error {
        return nil
 }
 
 func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req 
flightsql.ActionCreatePreparedStatementRequest) (result 
flightsql.ActionCreatePreparedStatementResult, err error) {
+       switch req.GetQuery() {
+       case "error_create_prepared_statement":
+               err = status.Error(codes.InvalidArgument, "expected error 
(DoAction)")
+               return
+       case "error_create_prepared_statement_detail":
+               detail1 := wrapperspb.String("detail1")
+               detail2 := wrapperspb.String("detail2")
+               err = StatusWithDetail(codes.InvalidArgument, "expected error 
(DoAction)", detail1, detail2)
+               return
+       }
        result.Handle = []byte(req.GetQuery())
        return
 }
 
 func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context, 
cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) 
(*flight.FlightInfo, error) {
-       if bytes.Equal(cmd.GetPreparedStatementHandle(), 
[]byte("error_do_get")) || bytes.Equal(cmd.GetPreparedStatementHandle(), 
[]byte("error_do_get_stream")) {
+       switch string(cmd.GetPreparedStatementHandle()) {
+       case "error_do_get", "error_do_get_stream", "error_do_get_detail", 
"error_do_get_stream_detail", "forever":
                schema := arrow.NewSchema([]arrow.Field{{Name: "ints", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
                return &flight.FlightInfo{
                        Endpoint:         []*flight.FlightEndpoint{{Ticket: 
&flight.Ticket{Ticket: desc.Cmd}}},
@@ -64,6 +90,12 @@ func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ 
context.Context, cmd
                        TotalBytes:       -1,
                        Schema:           flight.SerializeSchema(schema, 
srv.Alloc),
                }, nil
+       case "error_get_flight_info":
+               return nil, status.Error(codes.InvalidArgument, "expected error 
(GetFlightInfo)")
+       case "error_get_flight_info_detail":
+               detail1 := wrapperspb.String("detail1")
+               detail2 := wrapperspb.String("detail2")
+               return nil, StatusWithDetail(codes.InvalidArgument, "expected 
error (GetFlightInfo)", detail1, detail2)
        }
 
        return &flight.FlightInfo{
@@ -90,8 +122,33 @@ func (srv *ExampleServer) GetFlightInfoStatement(ctx 
context.Context, cmd flight
 
 func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd 
flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan 
flight.StreamChunk, err error) {
        log.Printf("DoGetPreparedStatement: %v", 
cmd.GetPreparedStatementHandle())
-       if bytes.Equal(cmd.GetPreparedStatementHandle(), 
[]byte("error_do_get")) {
-               err = status.Error(codes.InvalidArgument, "expected error")
+       switch string(cmd.GetPreparedStatementHandle()) {
+       case "error_do_get":
+               err = status.Error(codes.InvalidArgument, "expected error 
(DoGet)")
+               return
+       case "error_do_get_detail":
+               detail1 := wrapperspb.String("detail1")
+               detail2 := wrapperspb.String("detail2")
+               err = StatusWithDetail(codes.InvalidArgument, "expected error 
(DoGet)", detail1, detail2)
+               return
+       case "forever":
+               ch := make(chan flight.StreamChunk)
+               schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+               var rec arrow.Record
+               rec, _, err = array.RecordFromJSON(memory.DefaultAllocator, 
schema, strings.NewReader(`[{"a": 5}]`))
+               go func() {
+                       // wait for client cancel
+                       <-ctx.Done()
+                       defer close(ch)
+
+                       // arrow-go crashes if we don't give this
+                       ch <- flight.StreamChunk{
+                               Data: rec,
+                               Desc: nil,
+                               Err:  nil,
+                       }
+               }()
+               out = ch
                return
        }
 
@@ -106,11 +163,20 @@ func (srv *ExampleServer) DoGetPreparedStatement(ctx 
context.Context, cmd flight
                        Desc: nil,
                        Err:  nil,
                }
-               if bytes.Equal(cmd.GetPreparedStatementHandle(), 
[]byte("error_do_get_stream")) {
+               switch string(cmd.GetPreparedStatementHandle()) {
+               case "error_do_get_stream":
+                       ch <- flight.StreamChunk{
+                               Data: nil,
+                               Desc: nil,
+                               Err:  status.Error(codes.InvalidArgument, 
"expected stream error (DoGet)"),
+                       }
+               case "error_do_get_stream_detail":
+                       detail1 := wrapperspb.String("detail1")
+                       detail2 := wrapperspb.String("detail2")
                        ch <- flight.StreamChunk{
                                Data: nil,
                                Desc: nil,
-                               Err:  status.Error(codes.InvalidArgument, 
"expected error"),
+                               Err:  StatusWithDetail(codes.InvalidArgument, 
"expected stream error (DoGet)", detail1, detail2),
                        }
                }
        }()
@@ -135,6 +201,23 @@ func (srv *ExampleServer) DoGetStatement(ctx 
context.Context, cmd flightsql.Stat
        return
 }
 
+func (srv *ExampleServer) DoPutPreparedStatementQuery(ctx context.Context, cmd 
flightsql.PreparedStatementQuery, reader flight.MessageReader, writer 
flight.MetadataWriter) error {
+       switch string(cmd.GetPreparedStatementHandle()) {
+       case "error_do_put":
+               return status.Error(codes.Unknown, "expected error (DoPut)")
+       case "error_do_put_detail":
+               detail1 := wrapperspb.String("detail1")
+               detail2 := wrapperspb.String("detail2")
+               return StatusWithDetail(codes.Unknown, "expected error 
(DoPut)", detail1, detail2)
+       }
+
+       return status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery 
not implemented")
+}
+
+func (srv *ExampleServer) DoPutPreparedStatementUpdate(context.Context, 
flightsql.PreparedStatementUpdate, flight.MessageReader) (int64, error) {
+       return 0, status.Error(codes.Unimplemented, 
"DoPutPreparedStatementUpdate not implemented")
+}
+
 func main() {
        var (
                host = flag.String("host", "localhost", "hostname to bind to")
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go 
b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index d43b9fd6..e6adaae1 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -37,12 +37,15 @@ import (
        "github.com/apache/arrow/go/v13/arrow/flight/flightsql"
        "github.com/apache/arrow/go/v13/arrow/flight/flightsql/schema_ref"
        "github.com/apache/arrow/go/v13/arrow/memory"
+       "github.com/golang/protobuf/ptypes/wrappers"
        "github.com/stretchr/testify/suite"
        "golang.org/x/exp/maps"
        "google.golang.org/grpc"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/metadata"
        "google.golang.org/grpc/status"
+       "google.golang.org/protobuf/proto"
+       "google.golang.org/protobuf/types/known/anypb"
        "google.golang.org/protobuf/types/known/wrapperspb"
 )
 
@@ -287,12 +290,15 @@ func (ts *ErrorDetailsTests) TestGetFlightInfo() {
 
        ts.Equal(1, len(adbcErr.Details))
 
-       wrapper, ok := adbcErr.Details[0].(*adbc.ProtobufErrorDetail)
-       ts.True(ok, "Got message: %#v", wrapper)
+       wrapper := adbcErr.Details[0]
        ts.Equal("grpc-status-details-bin", wrapper.Key())
 
-       message, ok := wrapper.Message.(*wrapperspb.Int32Value)
-       ts.True(ok, "Got message: %#v", message)
+       raw, err := wrapper.Serialize()
+       ts.NoError(err)
+       any := anypb.Any{}
+       ts.NoError(proto.Unmarshal(raw, &any))
+       message := wrappers.Int32Value{}
+       ts.NoError(any.UnmarshalTo(&message))
        ts.Equal(int32(42), message.Value)
 }
 
@@ -319,12 +325,15 @@ func (ts *ErrorDetailsTests) TestDoGet() {
 
        ts.Equal(1, len(adbcErr.Details))
 
-       wrapper, ok := adbcErr.Details[0].(*adbc.ProtobufErrorDetail)
-       ts.True(ok, "Got message: %#v", wrapper)
+       wrapper := adbcErr.Details[0]
        ts.Equal("grpc-status-details-bin", wrapper.Key())
 
-       message, ok := wrapper.Message.(*wrapperspb.Int32Value)
-       ts.True(ok, "Got message: %#v", message)
+       raw, err := wrapper.Serialize()
+       ts.NoError(err)
+       any := anypb.Any{}
+       ts.NoError(proto.Unmarshal(raw, &any))
+       message := wrappers.Int32Value{}
+       ts.NoError(any.UnmarshalTo(&message))
        ts.Equal(int32(42), message.Value)
 }
 
diff --git a/go/adbc/driver/flightsql/utils.go 
b/go/adbc/driver/flightsql/utils.go
index d0d1af85..fef8e738 100644
--- a/go/adbc/driver/flightsql/utils.go
+++ b/go/adbc/driver/flightsql/utils.go
@@ -26,6 +26,7 @@ import (
        "google.golang.org/grpc/metadata"
        "google.golang.org/grpc/status"
        "google.golang.org/protobuf/proto"
+       "google.golang.org/protobuf/types/known/anypb"
 )
 
 func adbcFromFlightStatus(err error, context string, args ...any) error {
@@ -81,14 +82,8 @@ func adbcFromFlightStatusWithDetails(err error, header, 
trailer metadata.MD, con
        }
 
        details := []adbc.ErrorDetail{}
-       // slice of proto.Message or error
-       for _, detail := range grpcStatus.Details() {
-               if err, ok := detail.(error); ok {
-                       details = append(details, &adbc.TextErrorDetail{Name: 
"grpc-status-details-bin", Detail: err.Error()})
-               } else if msg, ok := detail.(proto.Message); ok {
-                       details = append(details, 
&adbc.ProtobufErrorDetail{Name: "grpc-status-details-bin", Message: msg})
-               }
-               // else, gRPC returned non-Protobuf detail in violation of 
their method contract
+       for _, detail := range grpcStatus.Proto().Details {
+               details = append(details, &anyErrorDetail{name: 
"grpc-status-details-bin", message: detail})
        }
 
        // XXX(https://github.com/grpc/grpc-go/issues/5485): don't count on
@@ -135,3 +130,20 @@ func checkContext(maybeErr error, ctx context.Context) 
error {
        }
        return ctx.Err()
 }
+
+// grpc's Status derps if you ask it to deserialize the error details, giving
+// you an error for each item. Instead, poke into its internals and directly
+// extract and return the protobuf Any to the client.
+type anyErrorDetail struct {
+       name    string
+       message *anypb.Any
+}
+
+func (d *anyErrorDetail) Key() string {
+       return d.name
+}
+
+// Serialize serializes the Protobuf message (wrapped in Any).
+func (d *anyErrorDetail) Serialize() ([]byte, error) {
+       return proto.Marshal(d.message)
+}
diff --git a/python/adbc_driver_flightsql/pyproject.toml 
b/python/adbc_driver_flightsql/pyproject.toml
index 58f73c77..ba3485f9 100644
--- a/python/adbc_driver_flightsql/pyproject.toml
+++ b/python/adbc_driver_flightsql/pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
 
 [project.optional-dependencies]
 dbapi = ["pandas", "pyarrow>=8.0.0"]
-test = ["pandas", "pyarrow>=8.0.0", "pytest"]
+test = ["pandas", "protobuf", "pyarrow>=8.0.0", "pytest"]
 
 [project.urls]
 homepage = "https://arrow.apache.org/adbc/";
diff --git a/python/adbc_driver_flightsql/tests/test_dbapi.py 
b/python/adbc_driver_flightsql/tests/test_dbapi.py
index e1990354..0918fc7a 100644
--- a/python/adbc_driver_flightsql/tests/test_dbapi.py
+++ b/python/adbc_driver_flightsql/tests/test_dbapi.py
@@ -33,21 +33,6 @@ def test_query_error(dremio_dbapi):
         assert exc.args[0].startswith("INVALID_ARGUMENT: [FlightSQL] ")
 
 
-def test_query_error_fetch(test_dbapi):
-    with test_dbapi.cursor() as cur:
-        cur.execute("error_do_get")
-        with pytest.raises(Exception, match="expected error"):
-            cur.fetch_arrow_table()
-
-
-def test_query_error_stream(test_dbapi):
-    with test_dbapi.cursor() as cur:
-        cur.execute("error_do_get_stream")
-        with pytest.raises(Exception, match="expected error"):
-            cur.fetchone()
-            cur.fetchone()
-
-
 def test_query_trivial(dremio_dbapi):
     with dremio_dbapi.cursor() as cur:
         cur.execute("SELECT 1")
diff --git a/python/adbc_driver_flightsql/tests/test_errors.py 
b/python/adbc_driver_flightsql/tests/test_errors.py
new file mode 100644
index 00000000..ed44b6a3
--- /dev/null
+++ b/python/adbc_driver_flightsql/tests/test_errors.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import re
+
+import google.protobuf.any_pb2 as any_pb2
+import google.protobuf.wrappers_pb2 as wrappers_pb2
+import pytest
+
+
+def assert_detail(e):
+    # Check that the expected error details are present
+    found = set()
+    for _, detail in e.details:
+        anyproto = any_pb2.Any()
+        anyproto.ParseFromString(detail)
+        string = wrappers_pb2.StringValue()
+        anyproto.Unpack(string)
+        found.add(string.value)
+    assert found == {"detail1", "detail2"}
+
+
+def test_query_cancel(test_dbapi):
+    with test_dbapi.cursor() as cur:
+        cur.execute("forever")
+        cur.adbc_cancel()
+        with pytest.raises(
+            test_dbapi.OperationalError,
+            match=re.escape("CANCELLED: [FlightSQL] context canceled"),
+        ):
+            cur.fetchone()
+
+
+def test_query_error_fetch(test_dbapi):
+    with test_dbapi.cursor() as cur:
+        cur.execute("error_do_get")
+        # Match more exactly to make sure there's not unexpected junk in the 
string
+        with pytest.raises(
+            test_dbapi.ProgrammingError,
+            match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error 
(DoGet)"),
+        ):
+            cur.fetch_arrow_table()
+
+        cur.execute("error_do_get_detail")
+        with pytest.raises(
+            test_dbapi.ProgrammingError,
+            match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error 
(DoGet)"),
+        ) as excval:
+            cur.fetch_arrow_table()
+        assert_detail(excval.value)
+
+
+def test_query_error_stream(test_dbapi):
+    with test_dbapi.cursor() as cur:
+        cur.execute("error_do_get_stream")
+        with pytest.raises(
+            test_dbapi.ProgrammingError,
+            match=re.escape(
+                "INVALID_ARGUMENT: [FlightSQL] expected stream error (DoGet)"
+            ),
+        ):
+            cur.fetchone()
+            cur.fetchone()
+
+        cur.execute("error_do_get_stream_detail")
+        with pytest.raises(
+            test_dbapi.ProgrammingError,
+            match=re.escape(
+                "INVALID_ARGUMENT: [FlightSQL] expected stream error (DoGet)"
+            ),
+        ) as excval:
+            cur.fetchone()
+            cur.fetchone()
+        assert_detail(excval.value)
+
+
+def test_query_error_bind(test_dbapi):
+    with test_dbapi.cursor() as cur:
+        cur.adbc_prepare("error_do_put")
+        with pytest.raises(
+            test_dbapi.OperationalError,
+            match=re.escape("UNKNOWN: [FlightSQL] expected error (DoPut)"),
+        ):
+            cur.execute("error_do_put", parameters=(1, "a"))
+
+        cur.adbc_prepare("error_do_put_detail")
+        with pytest.raises(
+            test_dbapi.OperationalError,
+            match=re.escape("UNKNOWN: [FlightSQL] expected error (DoPut)"),
+        ) as excval:
+            cur.execute("error_do_put_detail", parameters=(1, "a"))
+        assert_detail(excval.value)
+
+
+def test_query_error_create_prepared_statement(test_dbapi):
+    with test_dbapi.cursor() as cur:
+        with pytest.raises(
+            test_dbapi.ProgrammingError,
+            match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error 
(DoAction)"),
+        ):
+            cur.adbc_prepare("error_create_prepared_statement")
+
+        with pytest.raises(
+            test_dbapi.ProgrammingError,
+            match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error 
(DoAction)"),
+        ) as excval:
+            cur.adbc_prepare("error_create_prepared_statement_detail")
+        assert_detail(excval.value)
+
+
+def test_query_error_getflightinfo(test_dbapi):
+    with test_dbapi.cursor() as cur:
+        with pytest.raises(
+            Exception,
+            match=re.escape(
+                "INVALID_ARGUMENT: [FlightSQL] expected error (GetFlightInfo)"
+            ),
+        ):
+            cur.execute("error_get_flight_info")
+
+        with pytest.raises(
+            Exception,
+            match=re.escape(
+                "INVALID_ARGUMENT: [FlightSQL] expected error (GetFlightInfo)"
+            ),
+        ) as excval:
+            cur.execute("error_get_flight_info_detail")
+        assert_detail(excval.value)
+
+        cur.adbc_prepare("error_get_flight_info")
+        with pytest.raises(
+            Exception,
+            match=re.escape(
+                "INVALID_ARGUMENT: [FlightSQL] expected error (GetFlightInfo)"
+            ),
+        ):
+            cur.adbc_execute_partitions("error_get_flight_info")
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd 
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
index 88a61a66..358a09aa 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -282,6 +282,7 @@ cdef const CAdbcError* PyAdbcErrorFromArrayStream(
     CArrowArrayStream* stream, CAdbcStatusCode* status)
 
 cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *
+cdef object convert_error(CAdbcStatusCode status, CAdbcError* error) except *
 
 cdef extern from "adbc_driver_manager.h":
     const char* CAdbcStatusCodeMessage"AdbcStatusCodeMessage"(CAdbcStatusCode 
code)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx 
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index e21130a7..ced8870e 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -158,15 +158,16 @@ INGEST_OPTION_MODE_CREATE_APPEND = 
ADBC_INGEST_OPTION_MODE_CREATE_APPEND.decode(
 INGEST_OPTION_TARGET_TABLE = ADBC_INGEST_OPTION_TARGET_TABLE.decode("utf-8")
 
 
-cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
+cdef object convert_error(CAdbcStatusCode status, CAdbcError* error) except *:
     cdef CAdbcErrorDetail c_detail
 
     if status == ADBC_STATUS_OK:
-        return
+        return None
 
     message = CAdbcStatusCodeMessage(status).decode("utf-8")
     vendor_code = None
     sqlstate = None
+    details = []
 
     if error != NULL:
         if error.message != NULL:
@@ -181,7 +182,6 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* 
error) except *:
             message += f". SQLSTATE: {sqlstate}"
 
         num_details = AdbcErrorGetDetailCount(error)
-        details = []
         for index in range(num_details):
             c_detail = AdbcErrorGetDetail(error, index)
             if c_detail.key == NULL or c_detail.value == NULL:
@@ -216,8 +216,15 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* 
error) except *:
                     ADBC_STATUS_UNAUTHORIZED):
         klass = ProgrammingError
     elif status == ADBC_STATUS_NOT_IMPLEMENTED:
-        raise NotSupportedError(message, vendor_code=vendor_code, 
sqlstate=sqlstate, details=details)
-    raise klass(message, status_code=status, vendor_code=vendor_code, 
sqlstate=sqlstate, details=details)
+        return NotSupportedError(message, vendor_code=vendor_code, 
sqlstate=sqlstate, details=details)
+    return klass(message, status_code=status, vendor_code=vendor_code, 
sqlstate=sqlstate, details=details)
+
+
+cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
+    if status == ADBC_STATUS_OK:
+        return
+
+    raise convert_error(status, error)
 
 
 cdef CAdbcError empty_error():
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx 
b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx
index 43de6740..e7dcc16f 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx
@@ -30,12 +30,14 @@ cdef class _AdbcErrorHelper:
 
     def check_error(self, exception):
         cdef:
-            CAdbcStatusCode c_status
+            CAdbcStatusCode c_status = ADBC_STATUS_OK
             const CAdbcError* error = 
PyAdbcErrorFromArrayStream(&self.c_stream, &c_status)
-            CAdbcErrorDetail detail
 
-        if error != NULL:
-            check_error(c_status, <CAdbcError*> error)
+        exc = convert_error(c_status, <CAdbcError*> error)
+        if exc is not None:
+            # Suppress "During handling of the above exception, another
+            # exception occurred"
+            raise exc from None
 
         raise exception
 

Reply via email to