This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new cc0c4e9c feat(go/adbc/drivermgr): Implement Remaining CGO Wrapper Methods that are Supported by SQLite Driver (#1304)
cc0c4e9c is described below

commit cc0c4e9c01b1feb7c175b22f7cc1b0d019751745
Author: Joel Lubinitsky <[email protected]>
AuthorDate: Tue Nov 21 08:49:40 2023 -0800

    feat(go/adbc/drivermgr): Implement Remaining CGO Wrapper Methods that are Supported by SQLite Driver (#1304)
    
    # What?
    Implementations for the following methods in the CGO wrapper for
    `adbc_driver_manager`:
    - `GetTableSchema`
    - `GetTableTypes`
    - `Commit`
    - `Rollback`
    - `GetParameterSchema`
    - `BindStream`
    
    # Why?
    The functionality exists in the C++ driver manager but is not yet
    accessible via the Go driver interface.
    
    # Notes
    Three methods in the wrapper remain unimplemented: `ExecutePartitions`,
    `ReadPartition`, and `SetSubstraitPlan`. These methods are not currently
    supported by the SQLite driver, which is the primary test target for
    these changes. They could still be implemented in the drivermgr
    wrapper without support in any specific driver, but doing so would make
    it more difficult to verify correct behavior. Adding those methods will
    likely require additional work to ensure their behavior can be tested,
    so they are left out of this round of implementations.
    
    Closes part of: #1291
---
 go/adbc/drivermgr/wrapper.go             |  92 +++++++++++++++--
 go/adbc/drivermgr/wrapper_sqlite_test.go | 163 ++++++++++++++++++++++++++++++-
 2 files changed, 244 insertions(+), 11 deletions(-)

diff --git a/go/adbc/drivermgr/wrapper.go b/go/adbc/drivermgr/wrapper.go
index 1e131e51..07bb94b8 100644
--- a/go/adbc/drivermgr/wrapper.go
+++ b/go/adbc/drivermgr/wrapper.go
@@ -32,6 +32,10 @@ package drivermgr
 //     return (struct ArrowArray*)malloc(sizeof(struct ArrowArray));
 // }
 //
+// struct ArrowArrayStream* allocArrStream() {
+//     return (struct ArrowArrayStream*)malloc(sizeof(struct 
ArrowArrayStream));
+// }
+//
 import "C"
 import (
        "context"
@@ -186,6 +190,15 @@ func getRdr(out *C.struct_ArrowArrayStream) 
(array.RecordReader, error) {
        return rdr.(array.RecordReader), nil
 }
 
+func getSchema(out *C.struct_ArrowSchema) (*arrow.Schema, error) {
+       // Maybe: ImportCArrowSchema should perform this check?
+       if out.format == nil {
+               return nil, nil
+       }
+
+       return 
cdata.ImportCArrowSchema((*cdata.CArrowSchema)(unsafe.Pointer(out)))
+}
+
 type cnxn struct {
        conn *C.struct_AdbcConnection
 }
@@ -255,19 +268,68 @@ func (c *cnxn) GetObjects(_ context.Context, depth 
adbc.ObjectDepth, catalog, db
 }
 
 func (c *cnxn) GetTableSchema(_ context.Context, catalog, dbSchema *string, 
tableName string) (*arrow.Schema, error) {
-       return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
+       var (
+               schema     C.struct_ArrowSchema
+               err        C.struct_AdbcError
+               catalog_   *C.char
+               dbSchema_  *C.char
+               tableName_ *C.char
+       )
+
+       if catalog != nil {
+               catalog_ = C.CString(*catalog)
+               defer C.free(unsafe.Pointer(catalog_))
+       }
+
+       if dbSchema != nil {
+               dbSchema_ = C.CString(*dbSchema)
+               defer C.free(unsafe.Pointer(dbSchema_))
+       }
+
+       tableName_ = C.CString(tableName)
+       defer C.free(unsafe.Pointer(tableName_))
+
+       if code := adbc.Status(C.AdbcConnectionGetTableSchema(c.conn, catalog_, 
dbSchema_, tableName_, &schema, &err)); code != adbc.StatusOK {
+               return nil, toAdbcError(code, &err)
+       }
+
+       return getSchema(&schema)
 }
 
 func (c *cnxn) GetTableTypes(context.Context) (array.RecordReader, error) {
-       return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
+       var (
+               out C.struct_ArrowArrayStream
+               err C.struct_AdbcError
+       )
+
+       if code := adbc.Status(C.AdbcConnectionGetTableTypes(c.conn, &out, 
&err)); code != adbc.StatusOK {
+               return nil, toAdbcError(code, &err)
+       }
+       return getRdr(&out)
 }
 
 func (c *cnxn) Commit(context.Context) error {
-       return &adbc.Error{Code: adbc.StatusNotImplemented}
+       var (
+               err C.struct_AdbcError
+       )
+
+       if code := adbc.Status(C.AdbcConnectionCommit(c.conn, &err)); code != 
adbc.StatusOK {
+               return toAdbcError(code, &err)
+       }
+
+       return nil
 }
 
 func (c *cnxn) Rollback(context.Context) error {
-       return &adbc.Error{Code: adbc.StatusNotImplemented}
+       var (
+               err C.struct_AdbcError
+       )
+
+       if code := adbc.Status(C.AdbcConnectionRollback(c.conn, &err)); code != 
adbc.StatusOK {
+               return toAdbcError(code, &err)
+       }
+
+       return nil
 }
 
 func (c *cnxn) NewStatement() (adbc.Statement, error) {
@@ -405,11 +467,29 @@ func (s *stmt) Bind(_ context.Context, values 
arrow.Record) error {
 }
 
 func (s *stmt) BindStream(_ context.Context, stream array.RecordReader) error {
-       return &adbc.Error{Code: adbc.StatusNotImplemented}
+       var (
+               arrStream   = C.allocArrStream()
+               cdArrStream = 
(*cdata.CArrowArrayStream)(unsafe.Pointer(arrStream))
+               err         C.struct_AdbcError
+       )
+       cdata.ExportRecordReader(stream, cdArrStream)
+       if code := adbc.Status(C.AdbcStatementBindStream(s.st, arrStream, 
&err)); code != adbc.StatusOK {
+               return toAdbcError(code, &err)
+       }
+       return nil
 }
 
 func (s *stmt) GetParameterSchema() (*arrow.Schema, error) {
-       return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
+       var (
+               schema C.struct_ArrowSchema
+               err    C.struct_AdbcError
+       )
+
+       if code := adbc.Status(C.AdbcStatementGetParameterSchema(s.st, &schema, 
&err)); code != adbc.StatusOK {
+               return nil, toAdbcError(code, &err)
+       }
+
+       return getSchema(&schema)
 }
 
 func (s *stmt) ExecutePartitions(context.Context) (*arrow.Schema, 
adbc.Partitions, int64, error) {
diff --git a/go/adbc/drivermgr/wrapper_sqlite_test.go 
b/go/adbc/drivermgr/wrapper_sqlite_test.go
index 580b0467..c33adf27 100644
--- a/go/adbc/drivermgr/wrapper_sqlite_test.go
+++ b/go/adbc/drivermgr/wrapper_sqlite_test.go
@@ -53,20 +53,25 @@ func (dm *DriverMgrSuite) SetupSuite() {
        })
        dm.NoError(err)
 
-       db, err := dm.db.Open(dm.ctx)
+       cnxn, err := dm.db.Open(dm.ctx)
        dm.NoError(err)
-       defer db.Close()
+       defer cnxn.Close()
 
-       stmt, err := db.NewStatement()
+       stmt, err := cnxn.NewStatement()
        dm.NoError(err)
        defer stmt.Close()
 
-       err = stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER PRIMARY 
KEY, name TEXT)")
-       dm.NoError(err)
+       dm.NoError(stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER 
PRIMARY KEY, name TEXT)"))
 
        nrows, err := stmt.ExecuteUpdate(dm.ctx)
        dm.NoError(err)
        dm.Equal(int64(0), nrows)
+
+       dm.NoError(stmt.SetSqlQuery("INSERT INTO test_table (id, name) VALUES 
(1, 'test')"))
+
+       nrows, err = stmt.ExecuteUpdate(dm.ctx)
+       dm.NoError(err)
+       dm.Equal(int64(1), nrows)
 }
 
 func (dm *DriverMgrSuite) SetupTest() {
@@ -334,6 +339,83 @@ func (dm *DriverMgrSuite) TestGetObjectsTableType() {
        dm.False(rdr.Next())
 }
 
+func (dm *DriverMgrSuite) TestGetTableSchema() {
+       schema, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "test_table")
+       dm.NoError(err)
+
+       expSchema := arrow.NewSchema(
+               []arrow.Field{
+                       {Name: "id", Type: arrow.PrimitiveTypes.Int64, 
Nullable: true},
+                       {Name: "name", Type: arrow.BinaryTypes.String, 
Nullable: true},
+               }, nil)
+       dm.True(expSchema.Equal(schema))
+}
+
+func (dm *DriverMgrSuite) TestGetTableSchemaInvalidTable() {
+       _, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "unknown_table")
+       dm.Error(err)
+}
+
+func (dm *DriverMgrSuite) TestGetTableSchemaCatalog() {
+       catalog := "does_not_exist"
+       schema, err := dm.conn.GetTableSchema(dm.ctx, &catalog, nil, 
"test_table")
+       dm.NoError(err)
+       dm.Nil(schema)
+}
+
+func (dm *DriverMgrSuite) TestGetTableSchemaDBSchema() {
+       dbSchema := "does_not_exist"
+       schema, err := dm.conn.GetTableSchema(dm.ctx, nil, &dbSchema, 
"test_table")
+       dm.NoError(err)
+       dm.Nil(schema)
+}
+
+func (dm *DriverMgrSuite) TestGetTableTypes() {
+       rdr, err := dm.conn.GetTableTypes(dm.ctx)
+       dm.NoError(err)
+       defer rdr.Release()
+
+       expSchema := adbc.TableTypesSchema
+       dm.True(expSchema.Equal(rdr.Schema()))
+       dm.True(rdr.Next())
+
+       rec := rdr.Record()
+       dm.Equal(int64(2), rec.NumRows())
+
+       expTableTypes := []string{"table", "view"}
+       dm.Contains(expTableTypes, rec.Column(0).ValueStr(0))
+       dm.Contains(expTableTypes, rec.Column(0).ValueStr(1))
+       dm.False(rdr.Next())
+}
+
+func (dm *DriverMgrSuite) TestCommit() {
+       err := dm.conn.Commit(dm.ctx)
+       dm.Error(err)
+       dm.ErrorContains(err, "No active transaction, cannot commit")
+}
+
+func (dm *DriverMgrSuite) TestCommitAutocommitDisabled() {
+       cnxnopt, ok := dm.conn.(adbc.PostInitOptions)
+       dm.True(ok)
+
+       dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, 
adbc.OptionValueDisabled))
+       dm.NoError(dm.conn.Commit(dm.ctx))
+}
+
+func (dm *DriverMgrSuite) TestRollback() {
+       err := dm.conn.Rollback(dm.ctx)
+       dm.Error(err)
+       dm.ErrorContains(err, "No active transaction, cannot rollback")
+}
+
+func (dm *DriverMgrSuite) TestRollbackAutocommitDisabled() {
+       cnxnopt, ok := dm.conn.(adbc.PostInitOptions)
+       dm.True(ok)
+
+       dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, 
adbc.OptionValueDisabled))
+       dm.NoError(dm.conn.Rollback(dm.ctx))
+}
+
 func (dm *DriverMgrSuite) TestSqlExecute() {
        query := "SELECT 1"
        st, err := dm.conn.NewStatement()
@@ -429,6 +511,77 @@ func (dm *DriverMgrSuite) TestSqlPrepareMultipleParams() {
        dm.False(rdr.Next())
 }
 
+func (dm *DriverMgrSuite) TestGetParameterSchema() {
+       query := "SELECT ?1, ?2"
+       st, err := dm.conn.NewStatement()
+       dm.Require().NoError(err)
+       dm.Require().NoError(st.SetSqlQuery(query))
+       defer st.Close()
+
+       expSchema := arrow.NewSchema([]arrow.Field{
+               {Name: "?1", Type: arrow.Null, Nullable: true},
+               {Name: "?2", Type: arrow.Null, Nullable: true},
+       }, nil)
+
+       schema, err := st.GetParameterSchema()
+       dm.NoError(err)
+
+       dm.True(expSchema.Equal(schema))
+}
+
+func (dm *DriverMgrSuite) TestBindStream() {
+       query := "SELECT ?1, ?2"
+       st, err := dm.conn.NewStatement()
+       dm.Require().NoError(err)
+       dm.Require().NoError(st.SetSqlQuery(query))
+       defer st.Close()
+
+       schema := arrow.NewSchema([]arrow.Field{
+               {Name: "1", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+               {Name: "2", Type: arrow.BinaryTypes.String, Nullable: true},
+       }, nil)
+
+       bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+       defer bldr.Release()
+
+       bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)
+       bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"one", 
"two", "three"}, nil)
+
+       rec1 := bldr.NewRecord()
+       defer rec1.Release()
+
+       bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{4, 5, 6}, nil)
+       bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"four", 
"five", "six"}, nil)
+
+       rec2 := bldr.NewRecord()
+       defer rec2.Release()
+
+       recsIn := []arrow.Record{rec1, rec2}
+       rdrIn, err := array.NewRecordReader(schema, recsIn)
+       dm.NoError(err)
+
+       dm.NoError(st.BindStream(dm.ctx, rdrIn))
+
+       rdrOut, _, err := st.ExecuteQuery(dm.ctx)
+       dm.NoError(err)
+       defer rdrOut.Release()
+
+       recsOut := make([]arrow.Record, 0)
+       for rdrOut.Next() {
+               rec := rdrOut.Record()
+               rec.Retain()
+               defer rec.Release()
+               recsOut = append(recsOut, rec)
+       }
+
+       tableIn := array.NewTableFromRecords(schema, recsIn)
+       defer tableIn.Release()
+       tableOut := array.NewTableFromRecords(schema, recsOut)
+       defer tableOut.Release()
+
+       dm.Truef(array.TableEqual(tableIn, tableOut), "expected: %s\ngot: %s", 
tableIn, tableOut)
+}
+
 func TestDriverMgr(t *testing.T) {
        suite.Run(t, new(DriverMgrSuite))
 }

Reply via email to