This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new cc0c4e9c feat(go/adbc/drivermgr): Implement Remaining CGO Wrapper
Methods that are Supported by SQLite Driver (#1304)
cc0c4e9c is described below
commit cc0c4e9c01b1feb7c175b22f7cc1b0d019751745
Author: Joel Lubinitsky <[email protected]>
AuthorDate: Tue Nov 21 08:49:40 2023 -0800
feat(go/adbc/drivermgr): Implement Remaining CGO Wrapper Methods that are
Supported by SQLite Driver (#1304)
# What?
Implementations for the following methods in the CGO wrapper for
`adbc_driver_manager`:
- `GetTableSchema`
- `GetTableTypes`
- `Commit`
- `Rollback`
- `GetParameterSchema`
- `BindStream`
# Why?
Functionality exists in C++ driver manager but not yet accessible via Go
driver interface.
# Notes
Three methods in the wrapper remain unimplemented: `ExecutePartitions`,
`ReadPartition`, and `SetSubstraitPlan`. These methods are not currently
supported by the SQLite driver, which is the primary test target for
these changes. It is still possible to implement them in the drivermgr
wrapper without support in specific drivers, but it does make it more
difficult to verify correct behavior. The effort to add those methods
will likely involve some additional work to ensure we are able to test
their behaviors, so they are being left out of this current round of
implementations.
Closes part of: #1291
---
go/adbc/drivermgr/wrapper.go | 92 +++++++++++++++--
go/adbc/drivermgr/wrapper_sqlite_test.go | 163 ++++++++++++++++++++++++++++++-
2 files changed, 244 insertions(+), 11 deletions(-)
diff --git a/go/adbc/drivermgr/wrapper.go b/go/adbc/drivermgr/wrapper.go
index 1e131e51..07bb94b8 100644
--- a/go/adbc/drivermgr/wrapper.go
+++ b/go/adbc/drivermgr/wrapper.go
@@ -32,6 +32,10 @@ package drivermgr
// return (struct ArrowArray*)malloc(sizeof(struct ArrowArray));
// }
//
+// struct ArrowArrayStream* allocArrStream() {
+// return (struct ArrowArrayStream*)malloc(sizeof(struct
ArrowArrayStream));
+// }
+//
import "C"
import (
"context"
@@ -186,6 +190,15 @@ func getRdr(out *C.struct_ArrowArrayStream)
(array.RecordReader, error) {
return rdr.(array.RecordReader), nil
}
+func getSchema(out *C.struct_ArrowSchema) (*arrow.Schema, error) {
+ // Maybe: ImportCArrowSchema should perform this check?
+ if out.format == nil {
+ return nil, nil
+ }
+
+ return
cdata.ImportCArrowSchema((*cdata.CArrowSchema)(unsafe.Pointer(out)))
+}
+
type cnxn struct {
conn *C.struct_AdbcConnection
}
@@ -255,19 +268,68 @@ func (c *cnxn) GetObjects(_ context.Context, depth
adbc.ObjectDepth, catalog, db
}
func (c *cnxn) GetTableSchema(_ context.Context, catalog, dbSchema *string,
tableName string) (*arrow.Schema, error) {
- return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
+ var (
+ schema C.struct_ArrowSchema
+ err C.struct_AdbcError
+ catalog_ *C.char
+ dbSchema_ *C.char
+ tableName_ *C.char
+ )
+
+ if catalog != nil {
+ catalog_ = C.CString(*catalog)
+ defer C.free(unsafe.Pointer(catalog_))
+ }
+
+ if dbSchema != nil {
+ dbSchema_ = C.CString(*dbSchema)
+ defer C.free(unsafe.Pointer(dbSchema_))
+ }
+
+ tableName_ = C.CString(tableName)
+ defer C.free(unsafe.Pointer(tableName_))
+
+ if code := adbc.Status(C.AdbcConnectionGetTableSchema(c.conn, catalog_,
dbSchema_, tableName_, &schema, &err)); code != adbc.StatusOK {
+ return nil, toAdbcError(code, &err)
+ }
+
+ return getSchema(&schema)
}
func (c *cnxn) GetTableTypes(context.Context) (array.RecordReader, error) {
- return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
+ var (
+ out C.struct_ArrowArrayStream
+ err C.struct_AdbcError
+ )
+
+ if code := adbc.Status(C.AdbcConnectionGetTableTypes(c.conn, &out,
&err)); code != adbc.StatusOK {
+ return nil, toAdbcError(code, &err)
+ }
+ return getRdr(&out)
}
func (c *cnxn) Commit(context.Context) error {
- return &adbc.Error{Code: adbc.StatusNotImplemented}
+ var (
+ err C.struct_AdbcError
+ )
+
+ if code := adbc.Status(C.AdbcConnectionCommit(c.conn, &err)); code !=
adbc.StatusOK {
+ return toAdbcError(code, &err)
+ }
+
+ return nil
}
func (c *cnxn) Rollback(context.Context) error {
- return &adbc.Error{Code: adbc.StatusNotImplemented}
+ var (
+ err C.struct_AdbcError
+ )
+
+ if code := adbc.Status(C.AdbcConnectionRollback(c.conn, &err)); code !=
adbc.StatusOK {
+ return toAdbcError(code, &err)
+ }
+
+ return nil
}
func (c *cnxn) NewStatement() (adbc.Statement, error) {
@@ -405,11 +467,29 @@ func (s *stmt) Bind(_ context.Context, values
arrow.Record) error {
}
func (s *stmt) BindStream(_ context.Context, stream array.RecordReader) error {
- return &adbc.Error{Code: adbc.StatusNotImplemented}
+ var (
+ arrStream = C.allocArrStream()
+ cdArrStream =
(*cdata.CArrowArrayStream)(unsafe.Pointer(arrStream))
+ err C.struct_AdbcError
+ )
+ cdata.ExportRecordReader(stream, cdArrStream)
+ if code := adbc.Status(C.AdbcStatementBindStream(s.st, arrStream,
&err)); code != adbc.StatusOK {
+ return toAdbcError(code, &err)
+ }
+ return nil
}
func (s *stmt) GetParameterSchema() (*arrow.Schema, error) {
- return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
+ var (
+ schema C.struct_ArrowSchema
+ err C.struct_AdbcError
+ )
+
+ if code := adbc.Status(C.AdbcStatementGetParameterSchema(s.st, &schema,
&err)); code != adbc.StatusOK {
+ return nil, toAdbcError(code, &err)
+ }
+
+ return getSchema(&schema)
}
func (s *stmt) ExecutePartitions(context.Context) (*arrow.Schema,
adbc.Partitions, int64, error) {
diff --git a/go/adbc/drivermgr/wrapper_sqlite_test.go
b/go/adbc/drivermgr/wrapper_sqlite_test.go
index 580b0467..c33adf27 100644
--- a/go/adbc/drivermgr/wrapper_sqlite_test.go
+++ b/go/adbc/drivermgr/wrapper_sqlite_test.go
@@ -53,20 +53,25 @@ func (dm *DriverMgrSuite) SetupSuite() {
})
dm.NoError(err)
- db, err := dm.db.Open(dm.ctx)
+ cnxn, err := dm.db.Open(dm.ctx)
dm.NoError(err)
- defer db.Close()
+ defer cnxn.Close()
- stmt, err := db.NewStatement()
+ stmt, err := cnxn.NewStatement()
dm.NoError(err)
defer stmt.Close()
- err = stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER PRIMARY
KEY, name TEXT)")
- dm.NoError(err)
+ dm.NoError(stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER
PRIMARY KEY, name TEXT)"))
nrows, err := stmt.ExecuteUpdate(dm.ctx)
dm.NoError(err)
dm.Equal(int64(0), nrows)
+
+ dm.NoError(stmt.SetSqlQuery("INSERT INTO test_table (id, name) VALUES
(1, 'test')"))
+
+ nrows, err = stmt.ExecuteUpdate(dm.ctx)
+ dm.NoError(err)
+ dm.Equal(int64(1), nrows)
}
func (dm *DriverMgrSuite) SetupTest() {
@@ -334,6 +339,83 @@ func (dm *DriverMgrSuite) TestGetObjectsTableType() {
dm.False(rdr.Next())
}
+func (dm *DriverMgrSuite) TestGetTableSchema() {
+ schema, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "test_table")
+ dm.NoError(err)
+
+ expSchema := arrow.NewSchema(
+ []arrow.Field{
+ {Name: "id", Type: arrow.PrimitiveTypes.Int64,
Nullable: true},
+ {Name: "name", Type: arrow.BinaryTypes.String,
Nullable: true},
+ }, nil)
+ dm.True(expSchema.Equal(schema))
+}
+
+func (dm *DriverMgrSuite) TestGetTableSchemaInvalidTable() {
+ _, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "unknown_table")
+ dm.Error(err)
+}
+
+func (dm *DriverMgrSuite) TestGetTableSchemaCatalog() {
+ catalog := "does_not_exist"
+ schema, err := dm.conn.GetTableSchema(dm.ctx, &catalog, nil,
"test_table")
+ dm.NoError(err)
+ dm.Nil(schema)
+}
+
+func (dm *DriverMgrSuite) TestGetTableSchemaDBSchema() {
+ dbSchema := "does_not_exist"
+ schema, err := dm.conn.GetTableSchema(dm.ctx, nil, &dbSchema,
"test_table")
+ dm.NoError(err)
+ dm.Nil(schema)
+}
+
+func (dm *DriverMgrSuite) TestGetTableTypes() {
+ rdr, err := dm.conn.GetTableTypes(dm.ctx)
+ dm.NoError(err)
+ defer rdr.Release()
+
+ expSchema := adbc.TableTypesSchema
+ dm.True(expSchema.Equal(rdr.Schema()))
+ dm.True(rdr.Next())
+
+ rec := rdr.Record()
+ dm.Equal(int64(2), rec.NumRows())
+
+ expTableTypes := []string{"table", "view"}
+ dm.Contains(expTableTypes, rec.Column(0).ValueStr(0))
+ dm.Contains(expTableTypes, rec.Column(0).ValueStr(1))
+ dm.False(rdr.Next())
+}
+
+func (dm *DriverMgrSuite) TestCommit() {
+ err := dm.conn.Commit(dm.ctx)
+ dm.Error(err)
+ dm.ErrorContains(err, "No active transaction, cannot commit")
+}
+
+func (dm *DriverMgrSuite) TestCommitAutocommitDisabled() {
+ cnxnopt, ok := dm.conn.(adbc.PostInitOptions)
+ dm.True(ok)
+
+ dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit,
adbc.OptionValueDisabled))
+ dm.NoError(dm.conn.Commit(dm.ctx))
+}
+
+func (dm *DriverMgrSuite) TestRollback() {
+ err := dm.conn.Rollback(dm.ctx)
+ dm.Error(err)
+ dm.ErrorContains(err, "No active transaction, cannot rollback")
+}
+
+func (dm *DriverMgrSuite) TestRollbackAutocommitDisabled() {
+ cnxnopt, ok := dm.conn.(adbc.PostInitOptions)
+ dm.True(ok)
+
+ dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit,
adbc.OptionValueDisabled))
+ dm.NoError(dm.conn.Rollback(dm.ctx))
+}
+
func (dm *DriverMgrSuite) TestSqlExecute() {
query := "SELECT 1"
st, err := dm.conn.NewStatement()
@@ -429,6 +511,77 @@ func (dm *DriverMgrSuite) TestSqlPrepareMultipleParams() {
dm.False(rdr.Next())
}
+func (dm *DriverMgrSuite) TestGetParameterSchema() {
+ query := "SELECT ?1, ?2"
+ st, err := dm.conn.NewStatement()
+ dm.Require().NoError(err)
+ dm.Require().NoError(st.SetSqlQuery(query))
+ defer st.Close()
+
+ expSchema := arrow.NewSchema([]arrow.Field{
+ {Name: "?1", Type: arrow.Null, Nullable: true},
+ {Name: "?2", Type: arrow.Null, Nullable: true},
+ }, nil)
+
+ schema, err := st.GetParameterSchema()
+ dm.NoError(err)
+
+ dm.True(expSchema.Equal(schema))
+}
+
+func (dm *DriverMgrSuite) TestBindStream() {
+ query := "SELECT ?1, ?2"
+ st, err := dm.conn.NewStatement()
+ dm.Require().NoError(err)
+ dm.Require().NoError(st.SetSqlQuery(query))
+ defer st.Close()
+
+ schema := arrow.NewSchema([]arrow.Field{
+ {Name: "1", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+ {Name: "2", Type: arrow.BinaryTypes.String, Nullable: true},
+ }, nil)
+
+ bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+ defer bldr.Release()
+
+ bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)
+ bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"one",
"two", "three"}, nil)
+
+ rec1 := bldr.NewRecord()
+ defer rec1.Release()
+
+ bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{4, 5, 6}, nil)
+ bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"four",
"five", "six"}, nil)
+
+ rec2 := bldr.NewRecord()
+ defer rec2.Release()
+
+ recsIn := []arrow.Record{rec1, rec2}
+ rdrIn, err := array.NewRecordReader(schema, recsIn)
+ dm.NoError(err)
+
+ dm.NoError(st.BindStream(dm.ctx, rdrIn))
+
+ rdrOut, _, err := st.ExecuteQuery(dm.ctx)
+ dm.NoError(err)
+ defer rdrOut.Release()
+
+ recsOut := make([]arrow.Record, 0)
+ for rdrOut.Next() {
+ rec := rdrOut.Record()
+ rec.Retain()
+ defer rec.Release()
+ recsOut = append(recsOut, rec)
+ }
+
+ tableIn := array.NewTableFromRecords(schema, recsIn)
+ defer tableIn.Release()
+ tableOut := array.NewTableFromRecords(schema, recsOut)
+ defer tableOut.Release()
+
+ dm.Truef(array.TableEqual(tableIn, tableOut), "expected: %s\ngot: %s",
tableIn, tableOut)
+}
+
func TestDriverMgr(t *testing.T) {
suite.Run(t, new(DriverMgrSuite))
}