This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 4a07b846d fix(go/adbc/drivermgr): adjust ingest helper to set target
before BindStream (#4308)
4a07b846d is described below
commit 4a07b846d681d87a685be957d899e95c33e92b86
Author: Arnold Wakim <[email protected]>
AuthorDate: Thu May 21 09:26:31 2026 +0200
fix(go/adbc/drivermgr): adjust ingest helper to set target before
BindStream (#4308)
## Summary
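Reorder `IngestStream` and `IngestStreamContext` to set
`OptionKeyIngestTargetTable` and `OptionKeyIngestMode` before calling
`BindStream`, matching drivers that require the ingest target up front
(as the FlightSQL driver did before this change).
Also relax the FlightSQL driver: `Bind` and `BindStream` now stage the
data when no statement has been prepared, even if the ingest target has
not been set yet. `ExecuteQuery` and `ExecuteUpdate` return
`StatusInvalidState` ("must set IngestTargetTable before bulk ingestion")
when data is staged but no target table is set and no statement is
prepared.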
---------
Co-authored-by: arnoldwakim <[email protected]>
---
.../driver/flightsql/flightsql_adbc_server_test.go | 144 +++++++++++++++++++++
go/adbc/driver/flightsql/flightsql_statement.go | 88 +++++++------
go/adbc/ext.go | 24 ++--
3 files changed, 204 insertions(+), 52 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 6567c0008..c10b38c88 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -3047,6 +3047,150 @@ func (suite *BulkIngestTests)
TestBulkIngestWithStream() {
suite.Equal(int64(5), totalRows)
}
+func (suite *BulkIngestTests) TestBulkIngestBindStreamBeforeOptions() {
+ stmt, err := suite.cnxn.NewStatement()
+ suite.Require().NoError(err)
+ defer validation.CheckedClose(suite.T(), stmt)
+
+ schema := arrow.NewSchema([]arrow.Field{
+ {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable:
false},
+ }, nil)
+
+ bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+ defer bldr.Release()
+
+ bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
+ rec1 := bldr.NewRecordBatch()
+ bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{2, 3}, nil)
+ rec2 := bldr.NewRecordBatch()
+ defer rec1.Release()
+ defer rec2.Release()
+
+ rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec1,
rec2})
+ suite.Require().NoError(err)
+ defer rdr.Release()
+
+ suite.Require().NoError(stmt.BindStream(context.Background(), rdr))
+
+ suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable,
"bind_first"))
+ suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode,
adbc.OptionValueIngestModeCreate))
+
+ nRows, err := stmt.ExecuteUpdate(context.Background())
+ suite.Require().NoError(err)
+ suite.Equal(int64(3), nRows)
+
+ requests := suite.server.GetIngestRequests()
+ suite.Require().Len(requests, 1)
+ suite.Equal("bind_first", requests[0].GetTable())
+}
+
+func (suite *BulkIngestTests) TestBulkIngestBindBeforeOptions() {
+ stmt, err := suite.cnxn.NewStatement()
+ suite.Require().NoError(err)
+ defer validation.CheckedClose(suite.T(), stmt)
+
+ schema := arrow.NewSchema([]arrow.Field{
+ {Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
+ }, nil)
+
+ bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+ defer bldr.Release()
+
+ bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{10, 20}, nil)
+ rec := bldr.NewRecordBatch()
+ defer rec.Release()
+
+ suite.Require().NoError(stmt.Bind(context.Background(), rec))
+
+ suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable,
"bind_batch_first"))
+ suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode,
adbc.OptionValueIngestModeCreate))
+
+ nRows, err := stmt.ExecuteUpdate(context.Background())
+ suite.Require().NoError(err)
+ suite.Equal(int64(2), nRows)
+
+ requests := suite.server.GetIngestRequests()
+ suite.Require().Len(requests, 1)
+ suite.Equal("bind_batch_first", requests[0].GetTable())
+}
+
+func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTarget() {
+ stmt, err := suite.cnxn.NewStatement()
+ suite.Require().NoError(err)
+ defer validation.CheckedClose(suite.T(), stmt)
+
+ schema := arrow.NewSchema([]arrow.Field{
+ {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable:
false},
+ }, nil)
+
+ bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+ defer bldr.Release()
+
+ bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
+ rec := bldr.NewRecordBatch()
+ defer rec.Release()
+
+ rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
+ suite.Require().NoError(err)
+ defer rdr.Release()
+
+ suite.Require().NoError(stmt.BindStream(context.Background(), rdr))
+
+ _, err = stmt.ExecuteUpdate(context.Background())
+ suite.Require().Error(err)
+ suite.Contains(err.Error(), "must set IngestTargetTable before bulk
ingestion")
+}
+
+func (suite *BulkIngestTests) TestBulkIngestBindMissingTarget() {
+ stmt, err := suite.cnxn.NewStatement()
+ suite.Require().NoError(err)
+ defer validation.CheckedClose(suite.T(), stmt)
+
+ schema := arrow.NewSchema([]arrow.Field{
+ {Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
+ }, nil)
+
+ bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+ defer bldr.Release()
+
+ bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
+ rec := bldr.NewRecordBatch()
+ defer rec.Release()
+
+ suite.Require().NoError(stmt.Bind(context.Background(), rec))
+
+ _, err = stmt.ExecuteUpdate(context.Background())
+ suite.Require().Error(err)
+ suite.Contains(err.Error(), "must set IngestTargetTable before bulk
ingestion")
+}
+
+func (suite *BulkIngestTests)
TestBulkIngestBindStreamMissingTargetExecuteQuery() {
+ stmt, err := suite.cnxn.NewStatement()
+ suite.Require().NoError(err)
+ defer validation.CheckedClose(suite.T(), stmt)
+
+ schema := arrow.NewSchema([]arrow.Field{
+ {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable:
false},
+ }, nil)
+
+ bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+ defer bldr.Release()
+
+ bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
+ rec := bldr.NewRecordBatch()
+ defer rec.Release()
+
+ rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
+ suite.Require().NoError(err)
+ defer rdr.Release()
+
+ suite.Require().NoError(stmt.BindStream(context.Background(), rdr))
+
+ _, _, err = stmt.ExecuteQuery(context.Background())
+ suite.Require().Error(err)
+ suite.Contains(err.Error(), "must set IngestTargetTable before bulk
ingestion")
+}
+
func (suite *BulkIngestTests) TestBulkIngestWithoutBind() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go
b/go/adbc/driver/flightsql/flightsql_statement.go
index eb034ec66..08d92c44b 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -503,6 +503,14 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr
array.RecordReader, n
return nil, -1, err
}
+ // Reject staged binds if no ingest target was provided
+ if s.targetTable == "" && s.prepared == nil && (s.bound != nil ||
s.streamBind != nil) {
+ return nil, -1, adbc.Error{
+ Msg: "[Flight SQL Statement] must set
IngestTargetTable before bulk ingestion",
+ Code: adbc.StatusInvalidState,
+ }
+ }
+
// Handle bulk ingest
if s.targetTable != "" {
nrec, err = s.executeIngest(ctx)
@@ -535,6 +543,14 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n
int64, err error) {
return -1, err
}
+ // Reject staged binds if no ingest target was provided
+ if s.targetTable == "" && s.prepared == nil && (s.bound != nil ||
s.streamBind != nil) {
+ return -1, adbc.Error{
+ Msg: "[Flight SQL Statement] must set
IngestTargetTable before bulk ingestion",
+ Code: adbc.StatusInvalidState,
+ }
+ }
+
// Handle bulk ingest
if s.targetTable != "" {
return s.executeIngest(ctx)
@@ -600,66 +616,58 @@ func (s *statement) SetSubstraitPlan(plan []byte) error {
// but it may not do this until the statement is closed or another
// record is bound.
func (s *statement) Bind(_ context.Context, values arrow.RecordBatch) error {
- // For bulk ingest, bind to the statement
- if s.targetTable != "" {
- if s.streamBind != nil {
- s.streamBind.Release()
- s.streamBind = nil
- }
- if s.bound != nil {
- s.bound.Release()
- }
- s.bound = values
- if s.bound != nil {
- s.bound.Retain()
- }
+ if s.targetTable != "" || s.prepared == nil {
+ s.setBound(values)
return nil
}
- if s.prepared == nil {
- return adbc.Error{
- Msg: "[Flight SQL Statement] must call Prepare or set
IngestTargetTable before calling Bind",
- Code: adbc.StatusInvalidState}
- }
-
- // calls retain
s.prepared.SetParameters(values)
return nil
}
+func (s *statement) setBound(values arrow.RecordBatch) {
+ if s.streamBind != nil {
+ s.streamBind.Release()
+ s.streamBind = nil
+ }
+ if s.bound != nil {
+ s.bound.Release()
+ }
+ s.bound = values
+ if s.bound != nil {
+ s.bound.Retain()
+ }
+}
+
// BindStream uses a record batch stream to bind parameters for this
// query. This can be used for bulk inserts or prepared statements.
//
// The driver will call Release on the record reader, but may not do this
// until Close is called.
func (s *statement) BindStream(_ context.Context, stream array.RecordReader)
error {
- // For bulk ingest, bind to the statement
- if s.targetTable != "" {
- if s.bound != nil {
- s.bound.Release()
- s.bound = nil
- }
- if s.streamBind != nil {
- s.streamBind.Release()
- }
- s.streamBind = stream
- if s.streamBind != nil {
- s.streamBind.Retain()
- }
+ if s.targetTable != "" || s.prepared == nil {
+ s.setStreamBound(stream)
return nil
}
- if s.prepared == nil {
- return adbc.Error{
- Msg: "[Flight SQL Statement] must call Prepare or set
IngestTargetTable before calling Bind",
- Code: adbc.StatusInvalidState}
- }
-
- // calls retain
s.prepared.SetRecordReader(stream)
return nil
}
+func (s *statement) setStreamBound(stream array.RecordReader) {
+ if s.bound != nil {
+ s.bound.Release()
+ s.bound = nil
+ }
+ if s.streamBind != nil {
+ s.streamBind.Release()
+ }
+ s.streamBind = stream
+ if s.streamBind != nil {
+ s.streamBind.Retain()
+ }
+}
+
// GetParameterSchema returns an Arrow schema representation of
// the expected parameters to be bound.
//
diff --git a/go/adbc/ext.go b/go/adbc/ext.go
index dca009709..4f0df0462 100644
--- a/go/adbc/ext.go
+++ b/go/adbc/ext.go
@@ -100,12 +100,7 @@ func IngestStream(ctx context.Context, cnxn Connection,
reader array.RecordReade
err = errors.Join(err, stmt.Close())
}()
- // Bind the record batch stream
- if err = stmt.BindStream(ctx, reader); err != nil {
- return -1, fmt.Errorf("error during ingestion: BindStream: %w",
err)
- }
-
- // Set required options
+ // Set required options before binding
if err = stmt.SetOption(OptionKeyIngestTargetTable, targetTable); err
!= nil {
return -1, fmt.Errorf("error during ingestion:
SetOption(target_table=%s): %w", targetTable, err)
}
@@ -113,6 +108,11 @@ func IngestStream(ctx context.Context, cnxn Connection,
reader array.RecordReade
return -1, fmt.Errorf("error during ingestion:
SetOption(mode=%s): %w", ingestMode, err)
}
+ // Bind the record batch stream
+ if err = stmt.BindStream(ctx, reader); err != nil {
+ return -1, fmt.Errorf("error during ingestion: BindStream: %w",
err)
+ }
+
// Set other options if provided
if opt.Catalog != "" {
if err = stmt.SetOption(OptionValueIngestTargetCatalog,
opt.Catalog); err != nil {
@@ -167,12 +167,7 @@ func IngestStreamContext(ctx context.Context, cnxn
ConnectionWithContext, reader
err = errors.Join(err, stmt.Close(ctx))
}()
- // Bind the record batch stream
- if err = stmt.BindStream(ctx, reader); err != nil {
- return -1, fmt.Errorf("error during ingestion: BindStream: %w",
err)
- }
-
- // Set required options
+ // Set required options before binding (some drivers require target
first)
if err = stmt.SetOption(ctx, OptionKeyIngestTargetTable, targetTable);
err != nil {
return -1, fmt.Errorf("error during ingestion:
SetOption(target_table=%s): %w", targetTable, err)
}
@@ -180,6 +175,11 @@ func IngestStreamContext(ctx context.Context, cnxn
ConnectionWithContext, reader
return -1, fmt.Errorf("error during ingestion:
SetOption(mode=%s): %w", ingestMode, err)
}
+ // Bind the record batch stream
+ if err = stmt.BindStream(ctx, reader); err != nil {
+ return -1, fmt.Errorf("error during ingestion: BindStream: %w",
err)
+ }
+
// Set other options if provided
if opt.Catalog != "" {
if err = stmt.SetOption(ctx, OptionValueIngestTargetCatalog,
opt.Catalog); err != nil {