This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 277bd3c76f GH-34219: [Go][FlightRPC] Add Transactions to Sqlite FlightSQL example (#34220)
277bd3c76f is described below

commit 277bd3c76fddfda8704982c8d3d50188a460a168
Author: Matt Topol <[email protected]>
AuthorDate: Fri Feb 17 13:29:03 2023 -0500

    GH-34219: [Go][FlightRPC] Add Transactions to Sqlite FlightSQL example (#34220)
    
    
    
    ### Rationale for this change
    
    ### What changes are included in this PR?
    Implementations for BeginTransaction/EndTransaction for the SQLite3 Flight SQL example server along with handling the transactions in Execute and Update methods.
    
    ### Are these changes tested?
    Yes, tests are included.
    
    ### Are there any user-facing changes?
    No.
    
    * Closes: #34219
    
    Authored-by: Matt Topol <[email protected]>
    Signed-off-by: Matt Topol <[email protected]>
---
 .github/workflows/go.yml                           |   4 +-
 go/arrow/flight/flightsql/example/sqlite_server.go | 112 +++++++++++++++++++--
 go/arrow/flight/flightsql/sqlite_server_test.go    |  63 ++++++++++++
 3 files changed, 170 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index 0d06555b88..bbfbf818b7 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -301,9 +301,9 @@ jobs:
           go-version: ${{ matrix.go }}
           cache: true
           cache-dependency-path: go/go.sum
-      - name: Brew Install Arrow
+      - name: Brew Install Arrow and pkg-config
         shell: bash
-        run: brew install apache-arrow
+        run: brew install apache-arrow pkg-config
       - name: Install staticcheck
         run: go install honnef.co/go/tools/cmd/staticcheck@${{ 
matrix.staticcheck }}
       - name: Build
diff --git a/go/arrow/flight/flightsql/example/sqlite_server.go 
b/go/arrow/flight/flightsql/example/sqlite_server.go
index 0742113000..605c226845 100644
--- a/go/arrow/flight/flightsql/example/sqlite_server.go
+++ b/go/arrow/flight/flightsql/example/sqlite_server.go
@@ -37,6 +37,7 @@
 package example
 
 import (
+       "bytes"
        "context"
        "database/sql"
        "fmt"
@@ -174,6 +175,23 @@ func CreateDB() (*sql.DB, error) {
        return db, nil
 }
 
+func encodeTransactionQuery(query string, transactionID flightsql.Transaction) 
([]byte, error) {
+       return flightsql.CreateStatementQueryTicket(
+               bytes.Join([][]byte{transactionID, []byte(query)}, []byte(":")))
+}
+
+func decodeTransactionQuery(ticket []byte) (txnID, query string, err error) {
+       id, queryBytes, found := bytes.Cut(ticket, []byte(":"))
+       if !found {
+               err = fmt.Errorf("%w: malformed ticket", arrow.ErrInvalid)
+               return
+       }
+
+       txnID = string(id)
+       query = string(queryBytes)
+       return
+}
+
 type Statement struct {
        stmt   *sql.Stmt
        params [][]interface{}
@@ -183,7 +201,8 @@ type SQLiteFlightSQLServer struct {
        flightsql.BaseServer
        db *sql.DB
 
-       prepared sync.Map
+       prepared         sync.Map
+       openTransactions sync.Map
 }
 
 func NewSQLiteFlightSQLServer(db *sql.DB) (*SQLiteFlightSQLServer, error) {
@@ -206,8 +225,8 @@ func (s *SQLiteFlightSQLServer) flightInfoForCommand(desc 
*flight.FlightDescript
 }
 
 func (s *SQLiteFlightSQLServer) GetFlightInfoStatement(ctx context.Context, 
cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) 
(*flight.FlightInfo, error) {
-       query := cmd.GetQuery()
-       tkt, err := flightsql.CreateStatementQueryTicket([]byte(query))
+       query, txnid := cmd.GetQuery(), cmd.GetTransactionId()
+       tkt, err := encodeTransactionQuery(query, txnid)
        if err != nil {
                return nil, err
        }
@@ -221,7 +240,21 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoStatement(ctx 
context.Context, cmd
 }
 
 func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd 
flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, 
error) {
-       return doGetQuery(ctx, s.Alloc, s.db, string(cmd.GetStatementHandle()), 
nil)
+       txnid, query, err := decodeTransactionQuery(cmd.GetStatementHandle())
+       if err != nil {
+               return nil, nil, err
+       }
+
+       var db dbQueryCtx = s.db
+       if txnid != "" {
+               tx, loaded := s.openTransactions.Load(txnid)
+               if !loaded {
+                       return nil, nil, fmt.Errorf("%w: invalid transaction id 
specified", arrow.ErrInvalid)
+               }
+               db = tx.(*sql.Tx)
+       }
+
+       return doGetQuery(ctx, s.Alloc, db, query, nil)
 }
 
 func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc 
*flight.FlightDescriptor) (*flight.FlightInfo, error) {
@@ -349,7 +382,22 @@ func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx 
context.Context) (*arrow.Sch
 }
 
 func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx 
context.Context, cmd flightsql.StatementUpdate) (int64, error) {
-       res, err := s.db.ExecContext(ctx, cmd.GetQuery())
+       var (
+               res sql.Result
+               err error
+       )
+
+       if len(cmd.GetTransactionId()) > 0 {
+               tx, loaded := 
s.openTransactions.Load(string(cmd.GetTransactionId()))
+               if !loaded {
+                       return -1, status.Error(codes.InvalidArgument, "invalid 
transaction handle provided")
+               }
+
+               res, err = tx.(*sql.Tx).ExecContext(ctx, cmd.GetQuery())
+       } else {
+               res, err = s.db.ExecContext(ctx, cmd.GetQuery())
+       }
+
        if err != nil {
                return 0, err
        }
@@ -357,7 +405,18 @@ func (s *SQLiteFlightSQLServer) 
DoPutCommandStatementUpdate(ctx context.Context,
 }
 
 func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, 
req flightsql.ActionCreatePreparedStatementRequest) (result 
flightsql.ActionCreatePreparedStatementResult, err error) {
-       stmt, err := s.db.PrepareContext(ctx, req.GetQuery())
+       var stmt *sql.Stmt
+
+       if len(req.GetTransactionId()) > 0 {
+               tx, loaded := 
s.openTransactions.Load(string(req.GetTransactionId()))
+               if !loaded {
+                       return result, status.Error(codes.InvalidArgument, 
"invalid transaction handle provided")
+               }
+               stmt, err = tx.(*sql.Tx).PrepareContext(ctx, req.GetQuery())
+       } else {
+               stmt, err = s.db.PrepareContext(ctx, req.GetQuery())
+       }
+
        if err != nil {
                return result, err
        }
@@ -394,7 +453,11 @@ func (s *SQLiteFlightSQLServer) 
GetFlightInfoPreparedStatement(_ context.Context
        }, nil
 }
 
-func doGetQuery(ctx context.Context, mem memory.Allocator, db *sql.DB, query 
string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan 
flight.StreamChunk, error) {
+type dbQueryCtx interface {
+       QueryContext(context.Context, string, ...any) (*sql.Rows, error)
+}
+
+func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, 
query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan 
flight.StreamChunk, error) {
        rows, err := db.QueryContext(ctx, query, args...)
        if err != nil {
                return nil, nil, err
@@ -687,3 +750,38 @@ func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx 
context.Context, cmd fli
        query := prepareQueryForGetKeys(filter)
        return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys)
 }
+
+func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req 
flightsql.ActionBeginTransactionRequest) (id []byte, err error) {
+       tx, err := s.db.Begin()
+       if err != nil {
+               return nil, status.Errorf(codes.Internal, "failed to begin 
transaction: %s", err.Error())
+       }
+
+       handle := genRandomString()
+       s.openTransactions.Store(string(handle), tx)
+       return handle, nil
+}
+
+func (s *SQLiteFlightSQLServer) EndTransaction(_ context.Context, req 
flightsql.ActionEndTransactionRequest) error {
+       if req.GetAction() == flightsql.EndTransactionUnspecified {
+               return status.Error(codes.InvalidArgument, "must specify Commit 
or Rollback to end transaction")
+       }
+
+       handle := string(req.GetTransactionId())
+       if tx, loaded := s.openTransactions.LoadAndDelete(handle); loaded {
+               txn := tx.(*sql.Tx)
+               switch req.GetAction() {
+               case flightsql.EndTransactionCommit:
+                       if err := txn.Commit(); err != nil {
+                               return status.Error(codes.Internal, "failed to 
commit transaction: "+err.Error())
+                       }
+               case flightsql.EndTransactionRollback:
+                       if err := txn.Rollback(); err != nil {
+                               return status.Error(codes.Internal, "failed to 
rollback transaction: "+err.Error())
+                       }
+               }
+               return nil
+       }
+
+       return status.Error(codes.InvalidArgument, "transaction id not found")
+}
diff --git a/go/arrow/flight/flightsql/sqlite_server_test.go 
b/go/arrow/flight/flightsql/sqlite_server_test.go
index 7df9d932ed..3e274e32df 100644
--- a/go/arrow/flight/flightsql/sqlite_server_test.go
+++ b/go/arrow/flight/flightsql/sqlite_server_test.go
@@ -36,6 +36,8 @@ import (
        "github.com/apache/arrow/go/v12/arrow/scalar"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/suite"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/status"
        "google.golang.org/protobuf/proto"
        sqlite3 "modernc.org/sqlite/lib"
 )
@@ -830,6 +832,67 @@ func (s *FlightSqliteServerSuite) TestCommandGetSqlInfo() {
        }
 }
 
+func (s *FlightSqliteServerSuite) TestTransactions() {
+       ctx := context.Background()
+       tx, err := s.cl.BeginTransaction(ctx)
+       s.Require().NoError(err)
+       s.Require().NotNil(tx)
+
+       s.True(tx.ID().IsValid())
+       s.NotEmpty(tx.ID())
+
+       _, err = tx.BeginSavepoint(ctx, "foobar")
+       s.Equal(codes.Unimplemented, status.Code(err))
+
+       info, err := tx.Execute(ctx, "SELECT * FROM intTable")
+       s.Require().NoError(err)
+       rdr, err := s.cl.DoGet(ctx, info.Endpoint[0].Ticket)
+       s.Require().NoError(err)
+
+       toTable := func(r *flight.Reader) arrow.Table {
+               defer r.Release()
+               recs := make([]arrow.Record, 0)
+               for rdr.Next() {
+                       r := rdr.Record()
+                       r.Retain()
+                       defer r.Release()
+                       recs = append(recs, r)
+               }
+
+               return array.NewTableFromRecords(rdr.Schema(), recs)
+       }
+       tbl := toTable(rdr)
+       defer tbl.Release()
+
+       rowCount := tbl.NumRows()
+
+       result, err := tx.ExecuteUpdate(ctx, `INSERT INTO intTable (keyName, 
value) VALUES
+                                                  ('KEYNAME1', 1001), 
('KEYNAME2', 1002), ('KEYNAME3', 1003)`)
+       s.Require().NoError(err)
+       s.EqualValues(3, result)
+
+       info, err = tx.Execute(ctx, "SELECT * FROM intTable")
+       s.Require().NoError(err)
+       rdr, err = s.cl.DoGet(ctx, info.Endpoint[0].Ticket)
+       s.Require().NoError(err)
+       tbl = toTable(rdr)
+       defer tbl.Release()
+       s.EqualValues(rowCount+3, tbl.NumRows())
+
+       s.Require().NoError(tx.Rollback(ctx))
+       // commit/rollback invalidates the transaction handle
+       s.ErrorIs(tx.Commit(ctx), flightsql.ErrInvalidTxn)
+       s.ErrorIs(tx.Rollback(ctx), flightsql.ErrInvalidTxn)
+
+       info, err = s.cl.Execute(ctx, "SELECT * FROM intTable")
+       s.Require().NoError(err)
+       rdr, err = s.cl.DoGet(ctx, info.Endpoint[0].Ticket)
+       s.Require().NoError(err)
+       tbl = toTable(rdr)
+       defer tbl.Release()
+       s.EqualValues(rowCount, tbl.NumRows())
+}
+
 func TestSqliteServer(t *testing.T) {
        suite.Run(t, new(FlightSqliteServerSuite))
 }

Reply via email to