This is an automated email from the ASF dual-hosted git repository.

neuyilan pushed a commit to branch fix_session_id_2
in repository https://gitbox.apache.org/repos/asf/iotdb-client-go.git

commit 6bf3881de1cd1694da356d022ff82500f2fccd78
Author: HouliangQi <[email protected]>
AuthorDate: Wed Jun 22 15:02:57 2022 +0800

    Fix the session ID error when reconnecting to other nodes
---
 client/session.go | 123 ++++++++++++++++++++++++++++++------------------------
 1 file changed, 68 insertions(+), 55 deletions(-)

diff --git a/client/session.go b/client/session.go
index c5b4192..36f2b38 100644
--- a/client/session.go
+++ b/client/session.go
@@ -67,7 +67,6 @@ type endPoint struct {
 }
 
 var endPointList = list.New()
-var session Session
 
 func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) 
error {
        if s.config.FetchSize <= 0 {
@@ -191,8 +190,8 @@ func (s *Session) Close() (r *rpc.TSStatus, err error) {
 func (s *Session) SetStorageGroup(storageGroupId string) (r *rpc.TSStatus, err 
error) {
        r, err = s.client.SetStorageGroup(context.Background(), s.sessionId, 
storageGroupId)
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.SetStorageGroup(context.Background(), s.sessionId, 
storageGroupId)
+               if s.reconnect() {
+                       r, err = s.client.SetStorageGroup(context.Background(), 
s.sessionId, storageGroupId)
                }
        }
        return r, err
@@ -208,8 +207,8 @@ func (s *Session) SetStorageGroup(storageGroupId string) (r 
*rpc.TSStatus, err e
 func (s *Session) DeleteStorageGroup(storageGroupId string) (r *rpc.TSStatus, 
err error) {
        r, err = s.client.DeleteStorageGroups(context.Background(), 
s.sessionId, []string{storageGroupId})
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.DeleteStorageGroups(context.Background(), s.sessionId, 
[]string{storageGroupId})
+               if s.reconnect() {
+                       r, err = 
s.client.DeleteStorageGroups(context.Background(), s.sessionId, 
[]string{storageGroupId})
                }
        }
        return r, err
@@ -225,8 +224,8 @@ func (s *Session) DeleteStorageGroup(storageGroupId string) 
(r *rpc.TSStatus, er
 func (s *Session) DeleteStorageGroups(storageGroupIds ...string) (r 
*rpc.TSStatus, err error) {
        r, err = s.client.DeleteStorageGroups(context.Background(), 
s.sessionId, storageGroupIds)
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.DeleteStorageGroups(context.Background(), s.sessionId, 
storageGroupIds)
+               if s.reconnect() {
+                       r, err = 
s.client.DeleteStorageGroups(context.Background(), s.sessionId, storageGroupIds)
                }
        }
        return r, err
@@ -247,8 +246,9 @@ func (s *Session) CreateTimeseries(path string, dataType 
TSDataType, encoding TS
                Compressor: int32(compressor), Attributes: attributes, Tags: 
tags}
        status, err := s.client.CreateTimeseries(context.Background(), &request)
        if err != nil && status == nil {
-               if reconnect() {
-                       status, err = 
session.client.CreateTimeseries(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       status, err = 
s.client.CreateTimeseries(context.Background(), &request)
                }
        }
        return status, err
@@ -285,8 +285,9 @@ func (s *Session) CreateMultiTimeseries(paths []string, 
dataTypes []TSDataType,
        r, err = s.client.CreateMultiTimeseries(context.Background(), &request)
 
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.CreateMultiTimeseries(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       r, err = 
s.client.CreateMultiTimeseries(context.Background(), &request)
                }
        }
 
@@ -303,8 +304,8 @@ func (s *Session) CreateMultiTimeseries(paths []string, 
dataTypes []TSDataType,
 func (s *Session) DeleteTimeseries(paths []string) (r *rpc.TSStatus, err 
error) {
        r, err = s.client.DeleteTimeseries(context.Background(), s.sessionId, 
paths)
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.DeleteTimeseries(context.Background(), s.sessionId, paths)
+               if s.reconnect() {
+                       r, err = 
s.client.DeleteTimeseries(context.Background(), s.sessionId, paths)
                }
        }
        return r, err
@@ -323,8 +324,9 @@ func (s *Session) DeleteData(paths []string, startTime 
int64, endTime int64) (r
        request := rpc.TSDeleteDataReq{SessionId: s.sessionId, Paths: paths, 
StartTime: startTime, EndTime: endTime}
        r, err = s.client.DeleteData(context.Background(), &request)
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.DeleteData(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       r, err = s.client.DeleteData(context.Background(), 
&request)
                }
        }
        return r, err
@@ -345,8 +347,9 @@ func (s *Session) InsertStringRecord(deviceId string, 
measurements []string, val
                Values: values, Timestamp: timestamp}
        r, err = s.client.InsertStringRecord(context.Background(), &request)
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.InsertStringRecord(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       r, err = 
s.client.InsertStringRecord(context.Background(), &request)
                }
        }
        return r, err
@@ -377,8 +380,9 @@ func (s *Session) ExecuteStatement(sql string) 
(*SessionDataSet, error) {
        resp, err := s.client.ExecuteStatement(context.Background(), &request)
 
        if err != nil && resp == nil {
-               if reconnect() {
-                       resp, err = 
session.client.ExecuteStatement(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       resp, err = 
s.client.ExecuteStatement(context.Background(), &request)
                }
        }
 
@@ -395,8 +399,9 @@ func (s *Session) ExecuteQueryStatement(sql string, 
timeoutMs *int64) (*SessionD
                        return nil, statusErr
                }
        } else {
-               if reconnect() {
-                       resp, err = 
session.client.ExecuteQueryStatement(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       resp, err = 
s.client.ExecuteQueryStatement(context.Background(), &request)
                        if statusErr := VerifySuccess(resp.Status); statusErr 
== nil {
                                return NewSessionDataSet(sql, resp.Columns, 
resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.client, 
s.sessionId, resp.QueryDataSet, resp.IgnoreTimeStamp != nil && 
*resp.IgnoreTimeStamp, s.config.FetchSize, timeoutMs), err
                        } else {
@@ -433,8 +438,9 @@ func (s *Session) InsertRecord(deviceId string, 
measurements []string, dataTypes
        r, err = s.client.InsertRecord(context.Background(), request)
 
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.InsertRecord(context.Background(), request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       r, err = s.client.InsertRecord(context.Background(), 
request)
                }
        }
 
@@ -500,8 +506,9 @@ func (s *Session) InsertRecordsOfOneDevice(deviceId string, 
timestamps []int64,
        r, err = s.client.InsertRecordsOfOneDevice(context.Background(), 
request)
 
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.InsertRecordsOfOneDevice(context.Background(), request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       r, err = 
s.client.InsertRecordsOfOneDevice(context.Background(), request)
                }
        }
 
@@ -528,8 +535,9 @@ func (s *Session) InsertRecords(deviceIds []string, 
measurements [][]string, dat
        } else {
                r, err = s.client.InsertRecords(context.Background(), request)
                if err != nil && r == nil {
-                       if reconnect() {
-                               r, err = 
session.client.InsertRecords(context.Background(), request)
+                       if s.reconnect() {
+                               request.SessionId = s.sessionId
+                               r, err = 
s.client.InsertRecords(context.Background(), request)
                        }
                }
                return r, err
@@ -555,8 +563,9 @@ func (s *Session) InsertTablets(tablets []*Tablet, sorted 
bool) (r *rpc.TSStatus
        }
        r, err = s.client.InsertTablets(context.Background(), request)
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.InsertTablets(context.Background(), request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       r, err = s.client.InsertTablets(context.Background(), 
request)
                }
        }
        return r, err
@@ -569,8 +578,9 @@ func (s *Session) ExecuteBatchStatement(inserts []string) 
(r *rpc.TSStatus, err
        }
        r, err = s.client.ExecuteBatchStatement(context.Background(), &request)
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.ExecuteBatchStatement(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       r, err = 
s.client.ExecuteBatchStatement(context.Background(), &request)
                }
        }
        return r, err
@@ -588,8 +598,9 @@ func (s *Session) ExecuteRawDataQuery(paths []string, 
startTime int64, endTime i
        resp, err := s.client.ExecuteRawDataQuery(context.Background(), 
&request)
 
        if err != nil && resp == nil {
-               if reconnect() {
-                       resp, err = 
session.client.ExecuteRawDataQuery(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       resp, err = 
s.client.ExecuteRawDataQuery(context.Background(), &request)
                }
        }
 
@@ -606,8 +617,9 @@ func (s *Session) ExecuteUpdateStatement(sql string) 
(*SessionDataSet, error) {
        resp, err := s.client.ExecuteUpdateStatement(context.Background(), 
&request)
 
        if err != nil && resp == nil {
-               if reconnect() {
-                       resp, err = 
session.client.ExecuteUpdateStatement(context.Background(), &request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       resp, err = 
s.client.ExecuteUpdateStatement(context.Background(), &request)
                }
        }
 
@@ -753,8 +765,9 @@ func (s *Session) InsertTablet(tablet *Tablet, sorted bool) 
(r *rpc.TSStatus, er
        r, err = s.client.InsertTablet(context.Background(), request)
 
        if err != nil && r == nil {
-               if reconnect() {
-                       r, err = 
session.client.InsertTablet(context.Background(), request)
+               if s.reconnect() {
+                       request.SessionId = s.sessionId
+                       r, err = s.client.InsertTablet(context.Background(), 
request)
                }
        }
 
@@ -791,7 +804,7 @@ func NewSession(config *Config) Session {
 }
 
 func NewClusterSession(ClusterConfig *ClusterConfig) Session {
-
+       session := Session{}
        node := endPoint{}
        for i := 0; i < len(ClusterConfig.NodeUrls); i++ {
                node.Host = strings.Split(ClusterConfig.NodeUrls[i], ":")[0]
@@ -823,16 +836,16 @@ func NewClusterSession(ClusterConfig *ClusterConfig) 
Session {
        return session
 }
 
-func initClusterConn(node endPoint) error {
+func (s *Session) initClusterConn(node endPoint) error {
        var err error
 
-       session.trans, err = thrift.NewTSocketConf(net.JoinHostPort(node.Host, 
node.Port), &thrift.TConfiguration{
+       s.trans, err = thrift.NewTSocketConf(net.JoinHostPort(node.Host, 
node.Port), &thrift.TConfiguration{
                ConnectTimeout: time.Duration(0), // Use 0 for no timeout
        })
        if err == nil {
-               session.trans = thrift.NewTFramedTransport(session.trans)
-               if !session.trans.IsOpen() {
-                       err = session.trans.Open()
+               s.trans = thrift.NewTFramedTransport(s.trans)
+               if !s.trans.IsOpen() {
+                       err = s.trans.Open()
                        if err != nil {
                                return err
                        }
@@ -840,24 +853,24 @@ func initClusterConn(node endPoint) error {
        }
        var protocolFactory thrift.TProtocolFactory
        protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
-       iprot := protocolFactory.GetProtocol(session.trans)
-       oprot := protocolFactory.GetProtocol(session.trans)
-       session.client = 
rpc.NewTSIServiceClient(thrift.NewTStandardClient(iprot, oprot))
-       req := rpc.TSOpenSessionReq{ClientProtocol: 
rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: 
session.config.TimeZone, Username: &session.config.UserName,
-               Password: &session.config.Password}
+       iprot := protocolFactory.GetProtocol(s.trans)
+       oprot := protocolFactory.GetProtocol(s.trans)
+       s.client = rpc.NewTSIServiceClient(thrift.NewTStandardClient(iprot, 
oprot))
+       req := rpc.TSOpenSessionReq{ClientProtocol: 
rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, 
Username: &s.config.UserName,
+               Password: &s.config.Password}
        fmt.Println(req)
-       resp, err := session.client.OpenSession(context.Background(), &req)
+       resp, err := s.client.OpenSession(context.Background(), &req)
        if err != nil {
                return err
        }
-       session.sessionId = resp.GetSessionId()
-       session.requestStatementId, err = 
session.client.RequestStatementId(context.Background(), session.sessionId)
+       s.sessionId = resp.GetSessionId()
+       s.requestStatementId, err = 
s.client.RequestStatementId(context.Background(), s.sessionId)
        if err != nil {
                return err
        }
 
-       session.SetTimeZone(session.config.TimeZone)
-       session.config.TimeZone, err = session.GetTimeZone()
+       s.SetTimeZone(s.config.TimeZone)
+       s.config.TimeZone, err = s.GetTimeZone()
        return err
 
 }
@@ -873,13 +886,13 @@ func getConfig(host string, port string, userName string, 
passWord string, fetch
        }
 }
 
-func reconnect() bool {
+func (s *Session) reconnect() bool {
        var err error
        var connectedSuccess = false
 
        for i := 0; i < 3; i++ {
                for e := endPointList.Front(); e != nil; e = e.Next() {
-                       err = initClusterConn(e.Value.(endPoint))
+                       err = s.initClusterConn(e.Value.(endPoint))
                        if err == nil {
                                connectedSuccess = true
                                break

Reply via email to