This is an automated email from the ASF dual-hosted git repository.

francischuang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite-avatica-go.git

commit 26a2c5fd1d6b11b1955fedd9dd570e41b768e57b
Author: Parag Jain <[email protected]>
AuthorDate: Mon Jun 5 19:08:41 2023 +0530

    [CALCITE-5754] Fix open statements leak
---
 connection.go | 20 +++++++++++++++++---
 rows.go       | 11 ++++++++---
 statement.go  |  2 +-
 3 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/connection.go b/connection.go
index 368dc34..22f5b83 100644
--- a/connection.go
+++ b/connection.go
@@ -145,9 +145,12 @@ func (c *conn) exec(ctx context.Context, query string, args []namedValue) (drive
                return nil, c.avaticaErrorToResponseErrorOrError(err)
        }
 
+       statementID := st.(*message.CreateStatementResponse).StatementId
+       defer c.closeStatement(context.Background(), statementID)
+
        res, err := c.httpClient.post(ctx, &message.PrepareAndExecuteRequest{
                ConnectionId:      c.connectionId,
-               StatementId:       st.(*message.CreateStatementResponse).StatementId,
+               StatementId:       statementID,
                Sql:               query,
                MaxRowsTotal:      c.config.maxRowsTotal,
                FirstFrameMaxSize: c.config.frameMaxSize,
@@ -188,21 +191,24 @@ func (c *conn) query(ctx context.Context, query string, args []namedValue) (driv
                return nil, c.avaticaErrorToResponseErrorOrError(err)
        }
 
+       statementID := st.(*message.CreateStatementResponse).StatementId
+
        res, err := c.httpClient.post(ctx, &message.PrepareAndExecuteRequest{
                ConnectionId:      c.connectionId,
-               StatementId:       st.(*message.CreateStatementResponse).StatementId,
+               StatementId:       statementID,
                Sql:               query,
                MaxRowsTotal:      c.config.maxRowsTotal,
                FirstFrameMaxSize: c.config.frameMaxSize,
        })
 
        if err != nil {
+               _ = c.closeStatement(context.Background(), statementID)
                return nil, c.avaticaErrorToResponseErrorOrError(err)
        }
 
        resultSets := res.(*message.ExecuteResponse).Results
 
-       return newRows(c, st.(*message.CreateStatementResponse).StatementId, resultSets), nil
+       return newRows(c, statementID, true, resultSets), nil
 }
 
 func (c *conn) avaticaErrorToResponseErrorOrError(err error) error {
@@ -243,3 +249,11 @@ func (c *conn) ResetSession(_ context.Context) error {
        }
        return nil
 }
+
+func (c *conn) closeStatement(ctx context.Context, statementID uint32) error {
+       _, err := c.httpClient.post(context.Background(), &message.CloseStatementRequest{
+               ConnectionId: c.connectionId,
+               StatementId:  statementID,
+       })
+       return c.avaticaErrorToResponseErrorOrError(err)
+}
diff --git a/rows.go b/rows.go
index b886b66..dd0cac4 100644
--- a/rows.go
+++ b/rows.go
@@ -38,6 +38,7 @@ type resultSet struct {
 type rows struct {
        conn             *conn
        statementID      uint32
+       closeStatement   bool
        resultSets       []*resultSet
        currentResultSet int
        columnNames      []string
@@ -64,9 +65,12 @@ func (r *rows) Columns() []string {
 
 // Close closes the rows iterator.
 func (r *rows) Close() error {
-
+       var err error
+       if r.closeStatement {
+               err = r.conn.closeStatement(context.Background(), r.statementID)
+       }
        r.conn = nil
-       return nil
+       return err
 }
 
 // Next is called to populate the next row of data into
@@ -139,7 +143,7 @@ func (r *rows) Next(dest []driver.Value) error {
 }
 
 // newRows create a new set of rows from a result set.
-func newRows(conn *conn, statementID uint32, resultSets []*message.ResultSetResponse) *rows {
+func newRows(conn *conn, statementID uint32, closeStatement bool, resultSets []*message.ResultSetResponse) *rows {
 
        var rsets []*resultSet
 
@@ -180,6 +184,7 @@ func newRows(conn *conn, statementID uint32, resultSets []*message.ResultSetResp
        return &rows{
                conn:             conn,
                statementID:      statementID,
+               closeStatement:   closeStatement,
                resultSets:       rsets,
                currentResultSet: 0,
        }
diff --git a/statement.go b/statement.go
index 12822ef..014b71a 100644
--- a/statement.go
+++ b/statement.go
@@ -173,7 +173,7 @@ func (s *stmt) query(ctx context.Context, args []namedValue) (driver.Rows, error
 
        resultSet := res.(*message.ExecuteResponse).Results
 
-       return newRows(s.conn, s.statementID, resultSet), nil
+       return newRows(s.conn, s.statementID, false, resultSet), nil
 }
 
 func (s *stmt) parametersToTypedValues(vals []namedValue) []*message.TypedValue {

Reply via email to