This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 4a07b846d fix(go/adbc/drivermgr): adjust ingest helper to set target before BindStream (#4308)
4a07b846d is described below

commit 4a07b846d681d87a685be957d899e95c33e92b86
Author: Arnold Wakim <[email protected]>
AuthorDate: Thu May 21 09:26:31 2026 +0200

    fix(go/adbc/drivermgr): adjust ingest helper to set target before BindStream (#4308)
    
    ## Summary
    
    Reorder `IngestStream` and `IngestStreamContext` to set
    `OptionKeyIngestTargetTable` and `OptionKeyIngestMode` before calling
    `BindStream`, matching drivers that require ingest targets up front
    (e.g., FlightSQL).
    
    ---------
    
    Co-authored-by: arnoldwakim <[email protected]>
---
 .../driver/flightsql/flightsql_adbc_server_test.go | 144 +++++++++++++++++++++
 go/adbc/driver/flightsql/flightsql_statement.go    |  88 +++++++------
 go/adbc/ext.go                                     |  24 ++--
 3 files changed, 204 insertions(+), 52 deletions(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 6567c0008..c10b38c88 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -3047,6 +3047,150 @@ func (suite *BulkIngestTests) TestBulkIngestWithStream() {
        suite.Equal(int64(5), totalRows)
 }
 
+func (suite *BulkIngestTests) TestBulkIngestBindStreamBeforeOptions() {
+       stmt, err := suite.cnxn.NewStatement()
+       suite.Require().NoError(err)
+       defer validation.CheckedClose(suite.T(), stmt)
+
+       schema := arrow.NewSchema([]arrow.Field{
+               {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
+       }, nil)
+
+       bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+       defer bldr.Release()
+
+       bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
+       rec1 := bldr.NewRecordBatch()
+       bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{2, 3}, nil)
+       rec2 := bldr.NewRecordBatch()
+       defer rec1.Release()
+       defer rec2.Release()
+
+       rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec1, rec2})
+       suite.Require().NoError(err)
+       defer rdr.Release()
+
+       suite.Require().NoError(stmt.BindStream(context.Background(), rdr))
+
+       suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bind_first"))
+       suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate))
+
+       nRows, err := stmt.ExecuteUpdate(context.Background())
+       suite.Require().NoError(err)
+       suite.Equal(int64(3), nRows)
+
+       requests := suite.server.GetIngestRequests()
+       suite.Require().Len(requests, 1)
+       suite.Equal("bind_first", requests[0].GetTable())
+}
+
+func (suite *BulkIngestTests) TestBulkIngestBindBeforeOptions() {
+       stmt, err := suite.cnxn.NewStatement()
+       suite.Require().NoError(err)
+       defer validation.CheckedClose(suite.T(), stmt)
+
+       schema := arrow.NewSchema([]arrow.Field{
+               {Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
+       }, nil)
+
+       bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+       defer bldr.Release()
+
+       bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{10, 20}, nil)
+       rec := bldr.NewRecordBatch()
+       defer rec.Release()
+
+       suite.Require().NoError(stmt.Bind(context.Background(), rec))
+
+       suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bind_batch_first"))
+       suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate))
+
+       nRows, err := stmt.ExecuteUpdate(context.Background())
+       suite.Require().NoError(err)
+       suite.Equal(int64(2), nRows)
+
+       requests := suite.server.GetIngestRequests()
+       suite.Require().Len(requests, 1)
+       suite.Equal("bind_batch_first", requests[0].GetTable())
+}
+
+func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTarget() {
+       stmt, err := suite.cnxn.NewStatement()
+       suite.Require().NoError(err)
+       defer validation.CheckedClose(suite.T(), stmt)
+
+       schema := arrow.NewSchema([]arrow.Field{
+               {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
+       }, nil)
+
+       bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+       defer bldr.Release()
+
+       bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
+       rec := bldr.NewRecordBatch()
+       defer rec.Release()
+
+       rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
+       suite.Require().NoError(err)
+       defer rdr.Release()
+
+       suite.Require().NoError(stmt.BindStream(context.Background(), rdr))
+
+       _, err = stmt.ExecuteUpdate(context.Background())
+       suite.Require().Error(err)
+       suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
+}
+
+func (suite *BulkIngestTests) TestBulkIngestBindMissingTarget() {
+       stmt, err := suite.cnxn.NewStatement()
+       suite.Require().NoError(err)
+       defer validation.CheckedClose(suite.T(), stmt)
+
+       schema := arrow.NewSchema([]arrow.Field{
+               {Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
+       }, nil)
+
+       bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+       defer bldr.Release()
+
+       bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
+       rec := bldr.NewRecordBatch()
+       defer rec.Release()
+
+       suite.Require().NoError(stmt.Bind(context.Background(), rec))
+
+       _, err = stmt.ExecuteUpdate(context.Background())
+       suite.Require().Error(err)
+       suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
+}
+
+func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTargetExecuteQuery() {
+       stmt, err := suite.cnxn.NewStatement()
+       suite.Require().NoError(err)
+       defer validation.CheckedClose(suite.T(), stmt)
+
+       schema := arrow.NewSchema([]arrow.Field{
+               {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
+       }, nil)
+
+       bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+       defer bldr.Release()
+
+       bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
+       rec := bldr.NewRecordBatch()
+       defer rec.Release()
+
+       rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
+       suite.Require().NoError(err)
+       defer rdr.Release()
+
+       suite.Require().NoError(stmt.BindStream(context.Background(), rdr))
+
+       _, _, err = stmt.ExecuteQuery(context.Background())
+       suite.Require().Error(err)
+       suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
+}
+
 func (suite *BulkIngestTests) TestBulkIngestWithoutBind() {
        stmt, err := suite.cnxn.NewStatement()
        suite.Require().NoError(err)
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index eb034ec66..08d92c44b 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -503,6 +503,14 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n
                return nil, -1, err
        }
 
+       // Reject staged binds if no ingest target was provided
+       if s.targetTable == "" && s.prepared == nil && (s.bound != nil || s.streamBind != nil) {
+               return nil, -1, adbc.Error{
+                       Msg:  "[Flight SQL Statement] must set IngestTargetTable before bulk ingestion",
+                       Code: adbc.StatusInvalidState,
+               }
+       }
+
        // Handle bulk ingest
        if s.targetTable != "" {
                nrec, err = s.executeIngest(ctx)
@@ -535,6 +543,14 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) {
                return -1, err
        }
 
+       // Reject staged binds if no ingest target was provided
+       if s.targetTable == "" && s.prepared == nil && (s.bound != nil || s.streamBind != nil) {
+               return -1, adbc.Error{
+                       Msg:  "[Flight SQL Statement] must set IngestTargetTable before bulk ingestion",
+                       Code: adbc.StatusInvalidState,
+               }
+       }
+
        // Handle bulk ingest
        if s.targetTable != "" {
                return s.executeIngest(ctx)
@@ -600,66 +616,58 @@ func (s *statement) SetSubstraitPlan(plan []byte) error {
 // but it may not do this until the statement is closed or another
 // record is bound.
 func (s *statement) Bind(_ context.Context, values arrow.RecordBatch) error {
-       // For bulk ingest, bind to the statement
-       if s.targetTable != "" {
-               if s.streamBind != nil {
-                       s.streamBind.Release()
-                       s.streamBind = nil
-               }
-               if s.bound != nil {
-                       s.bound.Release()
-               }
-               s.bound = values
-               if s.bound != nil {
-                       s.bound.Retain()
-               }
+       if s.targetTable != "" || s.prepared == nil {
+               s.setBound(values)
                return nil
        }
 
-       if s.prepared == nil {
-               return adbc.Error{
-                       Msg:  "[Flight SQL Statement] must call Prepare or set IngestTargetTable before calling Bind",
-                       Code: adbc.StatusInvalidState}
-       }
-
-       // calls retain
        s.prepared.SetParameters(values)
        return nil
 }
 
+func (s *statement) setBound(values arrow.RecordBatch) {
+       if s.streamBind != nil {
+               s.streamBind.Release()
+               s.streamBind = nil
+       }
+       if s.bound != nil {
+               s.bound.Release()
+       }
+       s.bound = values
+       if s.bound != nil {
+               s.bound.Retain()
+       }
+}
+
 // BindStream uses a record batch stream to bind parameters for this
 // query. This can be used for bulk inserts or prepared statements.
 //
 // The driver will call Release on the record reader, but may not do this
 // until Close is called.
 func (s *statement) BindStream(_ context.Context, stream array.RecordReader) error {
-       // For bulk ingest, bind to the statement
-       if s.targetTable != "" {
-               if s.bound != nil {
-                       s.bound.Release()
-                       s.bound = nil
-               }
-               if s.streamBind != nil {
-                       s.streamBind.Release()
-               }
-               s.streamBind = stream
-               if s.streamBind != nil {
-                       s.streamBind.Retain()
-               }
+       if s.targetTable != "" || s.prepared == nil {
+               s.setStreamBound(stream)
                return nil
        }
 
-       if s.prepared == nil {
-               return adbc.Error{
-                       Msg:  "[Flight SQL Statement] must call Prepare or set IngestTargetTable before calling Bind",
-                       Code: adbc.StatusInvalidState}
-       }
-
-       // calls retain
        s.prepared.SetRecordReader(stream)
        return nil
 }
 
+func (s *statement) setStreamBound(stream array.RecordReader) {
+       if s.bound != nil {
+               s.bound.Release()
+               s.bound = nil
+       }
+       if s.streamBind != nil {
+               s.streamBind.Release()
+       }
+       s.streamBind = stream
+       if s.streamBind != nil {
+               s.streamBind.Retain()
+       }
+}
+
 // GetParameterSchema returns an Arrow schema representation of
 // the expected parameters to be bound.
 //
diff --git a/go/adbc/ext.go b/go/adbc/ext.go
index dca009709..4f0df0462 100644
--- a/go/adbc/ext.go
+++ b/go/adbc/ext.go
@@ -100,12 +100,7 @@ func IngestStream(ctx context.Context, cnxn Connection, reader array.RecordReade
                err = errors.Join(err, stmt.Close())
        }()
 
-       // Bind the record batch stream
-       if err = stmt.BindStream(ctx, reader); err != nil {
-               return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
-       }
-
-       // Set required options
+       // Set required options before binding
        if err = stmt.SetOption(OptionKeyIngestTargetTable, targetTable); err != nil {
                return -1, fmt.Errorf("error during ingestion: SetOption(target_table=%s): %w", targetTable, err)
        }
@@ -113,6 +108,11 @@ func IngestStream(ctx context.Context, cnxn Connection, reader array.RecordReade
                return -1, fmt.Errorf("error during ingestion: SetOption(mode=%s): %w", ingestMode, err)
        }
 
+       // Bind the record batch stream
+       if err = stmt.BindStream(ctx, reader); err != nil {
+               return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
+       }
+
        // Set other options if provided
        if opt.Catalog != "" {
                if err = stmt.SetOption(OptionValueIngestTargetCatalog, opt.Catalog); err != nil {
@@ -167,12 +167,7 @@ func IngestStreamContext(ctx context.Context, cnxn ConnectionWithContext, reader
                err = errors.Join(err, stmt.Close(ctx))
        }()
 
-       // Bind the record batch stream
-               return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
err)
-       }
-
-       // Set required options
+       // Set required options before binding (some drivers require target first)
        if err = stmt.SetOption(ctx, OptionKeyIngestTargetTable, targetTable); err != nil {
                return -1, fmt.Errorf("error during ingestion: SetOption(target_table=%s): %w", targetTable, err)
        }
@@ -180,6 +175,11 @@ func IngestStreamContext(ctx context.Context, cnxn ConnectionWithContext, reader
                return -1, fmt.Errorf("error during ingestion: SetOption(mode=%s): %w", ingestMode, err)
        }
 
+       // Bind the record batch stream
+       if err = stmt.BindStream(ctx, reader); err != nil {
+               return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
+       }
+
        // Set other options if provided
        if opt.Catalog != "" {
                if err = stmt.SetOption(ctx, OptionValueIngestTargetCatalog, opt.Catalog); err != nil {

Reply via email to