This is an automated email from the ASF dual-hosted git repository. francischuang pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite-avatica-go.git
commit 26a2c5fd1d6b11b1955fedd9dd570e41b768e57b Author: Parag Jain <[email protected]> AuthorDate: Mon Jun 5 19:08:41 2023 +0530 [CALCITE-5754] Fix open statements leak --- connection.go | 20 +++++++++++++++++--- rows.go | 11 ++++++++--- statement.go | 2 +- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/connection.go b/connection.go index 368dc34..22f5b83 100644 --- a/connection.go +++ b/connection.go @@ -145,9 +145,12 @@ func (c *conn) exec(ctx context.Context, query string, args []namedValue) (drive return nil, c.avaticaErrorToResponseErrorOrError(err) } + statementID := st.(*message.CreateStatementResponse).StatementId + defer c.closeStatement(context.Background(), statementID) + res, err := c.httpClient.post(ctx, &message.PrepareAndExecuteRequest{ ConnectionId: c.connectionId, - StatementId: st.(*message.CreateStatementResponse).StatementId, + StatementId: statementID, Sql: query, MaxRowsTotal: c.config.maxRowsTotal, FirstFrameMaxSize: c.config.frameMaxSize, @@ -188,21 +191,24 @@ func (c *conn) query(ctx context.Context, query string, args []namedValue) (driv return nil, c.avaticaErrorToResponseErrorOrError(err) } + statementID := st.(*message.CreateStatementResponse).StatementId + res, err := c.httpClient.post(ctx, &message.PrepareAndExecuteRequest{ ConnectionId: c.connectionId, - StatementId: st.(*message.CreateStatementResponse).StatementId, + StatementId: statementID, Sql: query, MaxRowsTotal: c.config.maxRowsTotal, FirstFrameMaxSize: c.config.frameMaxSize, }) if err != nil { + _ = c.closeStatement(context.Background(), statementID) return nil, c.avaticaErrorToResponseErrorOrError(err) } resultSets := res.(*message.ExecuteResponse).Results - return newRows(c, st.(*message.CreateStatementResponse).StatementId, resultSets), nil + return newRows(c, statementID, true, resultSets), nil } func (c *conn) avaticaErrorToResponseErrorOrError(err error) error { @@ -243,3 +249,11 @@ func (c *conn) ResetSession(_ context.Context) error { } return nil } + +func (c *conn) closeStatement(ctx context.Context, statementID uint32) error { + _, err := c.httpClient.post(context.Background(), &message.CloseStatementRequest{ + ConnectionId: c.connectionId, + StatementId: statementID, + }) + return c.avaticaErrorToResponseErrorOrError(err) +} diff --git a/rows.go b/rows.go index b886b66..dd0cac4 100644 --- a/rows.go +++ b/rows.go @@ -38,6 +38,7 @@ type resultSet struct { type rows struct { conn *conn statementID uint32 + closeStatement bool resultSets []*resultSet currentResultSet int columnNames []string @@ -64,9 +65,12 @@ func (r *rows) Columns() []string { // Close closes the rows iterator. func (r *rows) Close() error { - + var err error + if r.closeStatement { + err = r.conn.closeStatement(context.Background(), r.statementID) + } r.conn = nil - return nil + return err } // Next is called to populate the next row of data into @@ -139,7 +143,7 @@ func (r *rows) Next(dest []driver.Value) error { } // newRows create a new set of rows from a result set. -func newRows(conn *conn, statementID uint32, resultSets []*message.ResultSetResponse) *rows { +func newRows(conn *conn, statementID uint32, closeStatement bool, resultSets []*message.ResultSetResponse) *rows { var rsets []*resultSet @@ -180,6 +184,7 @@ func newRows(conn *conn, statementID uint32, resultSets []*message.ResultSetResp return &rows{ conn: conn, statementID: statementID, + closeStatement: closeStatement, resultSets: rsets, currentResultSet: 0, } diff --git a/statement.go b/statement.go index 12822ef..014b71a 100644 --- a/statement.go +++ b/statement.go @@ -173,7 +173,7 @@ func (s *stmt) query(ctx context.Context, args []namedValue) (driver.Rows, error resultSet := res.(*message.ExecuteResponse).Results - return newRows(s.conn, s.statementID, resultSet), nil + return newRows(s.conn, s.statementID, false, resultSet), nil } func (s *stmt) parametersToTypedValues(vals []namedValue) []*message.TypedValue {
