This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 277bd3c76f GH-34219: [Go][FlightRPC] Add Transactions to Sqlite
FlightSQL example (#34220)
277bd3c76f is described below
commit 277bd3c76fddfda8704982c8d3d50188a460a168
Author: Matt Topol <[email protected]>
AuthorDate: Fri Feb 17 13:29:03 2023 -0500
GH-34219: [Go][FlightRPC] Add Transactions to Sqlite FlightSQL example
(#34220)
### Rationale for this change
### What changes are included in this PR?
Implementations for BeginTransaction/EndTransaction for the SQLite3 Flight
SQL example server along with handling the transactions in Execute and Update
methods.
### Are these changes tested?
Yes, tests are included.
### Are there any user-facing changes?
No.
* Closes: #34219
Authored-by: Matt Topol <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
---
.github/workflows/go.yml | 4 +-
go/arrow/flight/flightsql/example/sqlite_server.go | 112 +++++++++++++++++++--
go/arrow/flight/flightsql/sqlite_server_test.go | 63 ++++++++++++
3 files changed, 170 insertions(+), 9 deletions(-)
diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index 0d06555b88..bbfbf818b7 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -301,9 +301,9 @@ jobs:
go-version: ${{ matrix.go }}
cache: true
cache-dependency-path: go/go.sum
- - name: Brew Install Arrow
+ - name: Brew Install Arrow and pkg-config
shell: bash
- run: brew install apache-arrow
+ run: brew install apache-arrow pkg-config
- name: Install staticcheck
run: go install honnef.co/go/tools/cmd/staticcheck@${{
matrix.staticcheck }}
- name: Build
diff --git a/go/arrow/flight/flightsql/example/sqlite_server.go
b/go/arrow/flight/flightsql/example/sqlite_server.go
index 0742113000..605c226845 100644
--- a/go/arrow/flight/flightsql/example/sqlite_server.go
+++ b/go/arrow/flight/flightsql/example/sqlite_server.go
@@ -37,6 +37,7 @@
package example
import (
+ "bytes"
"context"
"database/sql"
"fmt"
@@ -174,6 +175,23 @@ func CreateDB() (*sql.DB, error) {
return db, nil
}
+func encodeTransactionQuery(query string, transactionID flightsql.Transaction)
([]byte, error) {
+ return flightsql.CreateStatementQueryTicket(
+ bytes.Join([][]byte{transactionID, []byte(query)}, []byte(":")))
+}
+
+func decodeTransactionQuery(ticket []byte) (txnID, query string, err error) {
+ id, queryBytes, found := bytes.Cut(ticket, []byte(":"))
+ if !found {
+ err = fmt.Errorf("%w: malformed ticket", arrow.ErrInvalid)
+ return
+ }
+
+ txnID = string(id)
+ query = string(queryBytes)
+ return
+}
+
type Statement struct {
stmt *sql.Stmt
params [][]interface{}
@@ -183,7 +201,8 @@ type SQLiteFlightSQLServer struct {
flightsql.BaseServer
db *sql.DB
- prepared sync.Map
+ prepared sync.Map
+ openTransactions sync.Map
}
func NewSQLiteFlightSQLServer(db *sql.DB) (*SQLiteFlightSQLServer, error) {
@@ -206,8 +225,8 @@ func (s *SQLiteFlightSQLServer) flightInfoForCommand(desc
*flight.FlightDescript
}
func (s *SQLiteFlightSQLServer) GetFlightInfoStatement(ctx context.Context,
cmd flightsql.StatementQuery, desc *flight.FlightDescriptor)
(*flight.FlightInfo, error) {
- query := cmd.GetQuery()
- tkt, err := flightsql.CreateStatementQueryTicket([]byte(query))
+ query, txnid := cmd.GetQuery(), cmd.GetTransactionId()
+ tkt, err := encodeTransactionQuery(query, txnid)
if err != nil {
return nil, err
}
@@ -221,7 +240,21 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoStatement(ctx
context.Context, cmd
}
func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd
flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk,
error) {
- return doGetQuery(ctx, s.Alloc, s.db, string(cmd.GetStatementHandle()),
nil)
+ txnid, query, err := decodeTransactionQuery(cmd.GetStatementHandle())
+ if err != nil {
+ return nil, nil, err
+ }
+
+ var db dbQueryCtx = s.db
+ if txnid != "" {
+ tx, loaded := s.openTransactions.Load(txnid)
+ if !loaded {
+ return nil, nil, fmt.Errorf("%w: invalid transaction id
specified", arrow.ErrInvalid)
+ }
+ db = tx.(*sql.Tx)
+ }
+
+ return doGetQuery(ctx, s.Alloc, db, query, nil)
}
func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc
*flight.FlightDescriptor) (*flight.FlightInfo, error) {
@@ -349,7 +382,22 @@ func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx
context.Context) (*arrow.Sch
}
func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx
context.Context, cmd flightsql.StatementUpdate) (int64, error) {
- res, err := s.db.ExecContext(ctx, cmd.GetQuery())
+ var (
+ res sql.Result
+ err error
+ )
+
+ if len(cmd.GetTransactionId()) > 0 {
+ tx, loaded :=
s.openTransactions.Load(string(cmd.GetTransactionId()))
+ if !loaded {
+ return -1, status.Error(codes.InvalidArgument, "invalid
transaction handle provided")
+ }
+
+ res, err = tx.(*sql.Tx).ExecContext(ctx, cmd.GetQuery())
+ } else {
+ res, err = s.db.ExecContext(ctx, cmd.GetQuery())
+ }
+
if err != nil {
return 0, err
}
@@ -357,7 +405,18 @@ func (s *SQLiteFlightSQLServer)
DoPutCommandStatementUpdate(ctx context.Context,
}
func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context,
req flightsql.ActionCreatePreparedStatementRequest) (result
flightsql.ActionCreatePreparedStatementResult, err error) {
- stmt, err := s.db.PrepareContext(ctx, req.GetQuery())
+ var stmt *sql.Stmt
+
+ if len(req.GetTransactionId()) > 0 {
+ tx, loaded :=
s.openTransactions.Load(string(req.GetTransactionId()))
+ if !loaded {
+ return result, status.Error(codes.InvalidArgument,
"invalid transaction handle provided")
+ }
+ stmt, err = tx.(*sql.Tx).PrepareContext(ctx, req.GetQuery())
+ } else {
+ stmt, err = s.db.PrepareContext(ctx, req.GetQuery())
+ }
+
if err != nil {
return result, err
}
@@ -394,7 +453,11 @@ func (s *SQLiteFlightSQLServer)
GetFlightInfoPreparedStatement(_ context.Context
}, nil
}
-func doGetQuery(ctx context.Context, mem memory.Allocator, db *sql.DB, query
string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan
flight.StreamChunk, error) {
+type dbQueryCtx interface {
+ QueryContext(context.Context, string, ...any) (*sql.Rows, error)
+}
+
+func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx,
query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan
flight.StreamChunk, error) {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, nil, err
@@ -687,3 +750,38 @@ func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx
context.Context, cmd fli
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys)
}
+
+func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req
flightsql.ActionBeginTransactionRequest) (id []byte, err error) {
+ tx, err := s.db.Begin()
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "failed to begin
transaction: %s", err.Error())
+ }
+
+ handle := genRandomString()
+ s.openTransactions.Store(string(handle), tx)
+ return handle, nil
+}
+
+func (s *SQLiteFlightSQLServer) EndTransaction(_ context.Context, req
flightsql.ActionEndTransactionRequest) error {
+ if req.GetAction() == flightsql.EndTransactionUnspecified {
+ return status.Error(codes.InvalidArgument, "must specify Commit
or Rollback to end transaction")
+ }
+
+ handle := string(req.GetTransactionId())
+ if tx, loaded := s.openTransactions.LoadAndDelete(handle); loaded {
+ txn := tx.(*sql.Tx)
+ switch req.GetAction() {
+ case flightsql.EndTransactionCommit:
+ if err := txn.Commit(); err != nil {
+ return status.Error(codes.Internal, "failed to
commit transaction: "+err.Error())
+ }
+ case flightsql.EndTransactionRollback:
+ if err := txn.Rollback(); err != nil {
+ return status.Error(codes.Internal, "failed to
rollback transaction: "+err.Error())
+ }
+ }
+ return nil
+ }
+
+ return status.Error(codes.InvalidArgument, "transaction id not found")
+}
diff --git a/go/arrow/flight/flightsql/sqlite_server_test.go
b/go/arrow/flight/flightsql/sqlite_server_test.go
index 7df9d932ed..3e274e32df 100644
--- a/go/arrow/flight/flightsql/sqlite_server_test.go
+++ b/go/arrow/flight/flightsql/sqlite_server_test.go
@@ -36,6 +36,8 @@ import (
"github.com/apache/arrow/go/v12/arrow/scalar"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
sqlite3 "modernc.org/sqlite/lib"
)
@@ -830,6 +832,67 @@ func (s *FlightSqliteServerSuite) TestCommandGetSqlInfo() {
}
}
+func (s *FlightSqliteServerSuite) TestTransactions() {
+ ctx := context.Background()
+ tx, err := s.cl.BeginTransaction(ctx)
+ s.Require().NoError(err)
+ s.Require().NotNil(tx)
+
+ s.True(tx.ID().IsValid())
+ s.NotEmpty(tx.ID())
+
+ _, err = tx.BeginSavepoint(ctx, "foobar")
+ s.Equal(codes.Unimplemented, status.Code(err))
+
+ info, err := tx.Execute(ctx, "SELECT * FROM intTable")
+ s.Require().NoError(err)
+ rdr, err := s.cl.DoGet(ctx, info.Endpoint[0].Ticket)
+ s.Require().NoError(err)
+
+ toTable := func(r *flight.Reader) arrow.Table {
+ defer r.Release()
+ recs := make([]arrow.Record, 0)
+ for rdr.Next() {
+ r := rdr.Record()
+ r.Retain()
+ defer r.Release()
+ recs = append(recs, r)
+ }
+
+ return array.NewTableFromRecords(rdr.Schema(), recs)
+ }
+ tbl := toTable(rdr)
+ defer tbl.Release()
+
+ rowCount := tbl.NumRows()
+
+ result, err := tx.ExecuteUpdate(ctx, `INSERT INTO intTable (keyName,
value) VALUES
+ ('KEYNAME1', 1001),
('KEYNAME2', 1002), ('KEYNAME3', 1003)`)
+ s.Require().NoError(err)
+ s.EqualValues(3, result)
+
+ info, err = tx.Execute(ctx, "SELECT * FROM intTable")
+ s.Require().NoError(err)
+ rdr, err = s.cl.DoGet(ctx, info.Endpoint[0].Ticket)
+ s.Require().NoError(err)
+ tbl = toTable(rdr)
+ defer tbl.Release()
+ s.EqualValues(rowCount+3, tbl.NumRows())
+
+ s.Require().NoError(tx.Rollback(ctx))
+ // commit/rollback invalidates the transaction handle
+ s.ErrorIs(tx.Commit(ctx), flightsql.ErrInvalidTxn)
+ s.ErrorIs(tx.Rollback(ctx), flightsql.ErrInvalidTxn)
+
+ info, err = s.cl.Execute(ctx, "SELECT * FROM intTable")
+ s.Require().NoError(err)
+ rdr, err = s.cl.DoGet(ctx, info.Endpoint[0].Ticket)
+ s.Require().NoError(err)
+ tbl = toTable(rdr)
+ defer tbl.Release()
+ s.EqualValues(rowCount, tbl.NumRows())
+}
+
func TestSqliteServer(t *testing.T) {
suite.Run(t, new(FlightSqliteServerSuite))
}