This is an automated email from the ASF dual-hosted git repository.

jiangtian pushed a commit to branch dev/1.3
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/dev/1.3 by this push:
     new c01d6b9f0fb [To dev/1.3] Add connection_time_out_in_ms for Python SDK (#14919)
c01d6b9f0fb is described below

commit c01d6b9f0fbf4b08ce594a055f7684ff834e62a9
Author: Haonan <[email protected]>
AuthorDate: Fri Feb 21 15:12:06 2025 +0800

    [To dev/1.3] Add connection_time_out_in_ms for Python SDK (#14919)
    
    * Add connection_time_out_in_ms for Python SDK (#14898)
    
    * Add connection_time_out_in_ms for Python SDK
    
    * fix bug
    
    * fix bug
    
    * remove line
---
 iotdb-client/client-py/iotdb/Session.py     | 63 ++++++++++++++++-------------
 iotdb-client/client-py/iotdb/SessionPool.py |  4 ++
 2 files changed, 38 insertions(+), 29 deletions(-)

diff --git a/iotdb-client/client-py/iotdb/Session.py b/iotdb-client/client-py/iotdb/Session.py
index 90018e06379..cece2ae474e 100644
--- a/iotdb-client/client-py/iotdb/Session.py
+++ b/iotdb-client/client-py/iotdb/Session.py
@@ -18,13 +18,12 @@
 
 import logging
 import random
-import struct
 import sys
-import ssl
+import struct
 import time
 import warnings
 from thrift.protocol import TBinaryProtocol, TCompactProtocol
-from thrift.transport import TSocket, TTransport, TSSLSocket
+from thrift.transport import TSocket, TTransport
 
 from iotdb.utils.SessionDataSet import SessionDataSet
 from .template.Template import Template
@@ -88,6 +87,7 @@ class Session(object):
         enable_redirection=True,
         use_ssl=False,
         ca_certs=None,
+        connection_timeout_in_ms=None,
     ):
         self.__host = host
         self.__port = port
@@ -110,6 +110,7 @@ class Session(object):
         self.__endpoint_to_connection = None
         self.__use_ssl = use_ssl
         self.__ca_certs = ca_certs
+        self.__connection_timeout_in_ms = connection_timeout_in_ms
 
     @classmethod
     def init_from_node_urls(
@@ -122,6 +123,7 @@ class Session(object):
         enable_redirection=True,
         use_ssl=False,
         ca_certs=None,
+        connection_timeout_in_ms=None,
     ):
         if node_urls is None:
             raise RuntimeError("node urls is empty")
@@ -135,6 +137,7 @@ class Session(object):
             enable_redirection,
             use_ssl=use_ssl,
             ca_certs=ca_certs,
+            connection_timeout_in_ms=connection_timeout_in_ms,
         )
         session.__hosts = []
         session.__ports = []
@@ -182,32 +185,7 @@ class Session(object):
             }
 
     def init_connection(self, endpoint):
-        try:
-            if self.__use_ssl:
-                if sys.version_info >= (3, 10):
-                    context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
-                else:
-                    context = ssl.SSLContext(ssl.PROTOCOL_TLS)
-                    context.verify_mode = ssl.CERT_REQUIRED
-                    context.check_hostname = True
-                context.load_verify_locations(cafile=self.__ca_certs)
-                socket = TSSLSocket.TSSLSocket(
-                    host=endpoint.ip, port=endpoint.port, ssl_context=context
-                )
-            else:
-                socket = TSocket.TSocket(endpoint.ip, endpoint.port)
-            transport = TTransport.TFramedTransport(socket)
-
-            if not transport.isOpen():
-                try:
-                    transport.open()
-                except TTransport.TTransportException as e:
-                    raise IoTDBConnectionException(e) from None
-        except ssl.SSLError as e:
-            print(f"SSL error occurred: {e}")
-        except Exception as e:
-            print(f"An unexpected error occurred: {e}")
-
+        transport = self.__get_transport(endpoint)
         if self.__enable_rpc_compression:
             client = Client(TCompactProtocol.TCompactProtocolAccelerated(transport))
         else:
@@ -254,6 +232,33 @@ class Session(object):
             self.__zone_id = self.get_time_zone()
         return SessionConnection(client, transport, session_id, statement_id)
 
+    def __get_transport(self, endpoint):
+        if self.__use_ssl:
+            import ssl
+            from thrift.transport import TSSLSocket
+
+            if sys.version_info >= (3, 10):
+                context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+            else:
+                context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+                context.verify_mode = ssl.CERT_REQUIRED
+                context.check_hostname = True
+            context.load_verify_locations(cafile=self.__ca_certs)
+            socket = TSSLSocket.TSSLSocket(
+                host=endpoint.ip, port=endpoint.port, ssl_context=context
+            )
+        else:
+            socket = TSocket.TSocket(endpoint.ip, endpoint.port)
+        socket.setTimeout(self.__connection_timeout_in_ms)
+        transport = TTransport.TFramedTransport(socket)
+
+        if not transport.isOpen():
+            try:
+                transport.open()
+            except TTransport.TTransportException as e:
+                raise IoTDBConnectionException(e) from None
+        return transport
+
     def is_open(self):
         return not self.__is_close
 
diff --git a/iotdb-client/client-py/iotdb/SessionPool.py b/iotdb-client/client-py/iotdb/SessionPool.py
index f6d74e16515..85cd17946d6 100644
--- a/iotdb-client/client-py/iotdb/SessionPool.py
+++ b/iotdb-client/client-py/iotdb/SessionPool.py
@@ -44,6 +44,7 @@ class PoolConfig(object):
         enable_compression: bool = False,
         use_ssl: bool = False,
         ca_certs: str = None,
+        connection_timeout_in_ms: int = None,
     ):
         self.host = host
         self.port = port
@@ -62,6 +63,7 @@ class PoolConfig(object):
         self.enable_compression = enable_compression
         self.use_ssl = use_ssl
         self.ca_certs = ca_certs
+        self.connection_timeout_in_ms = connection_timeout_in_ms
 
 
 class SessionPool(object):
@@ -86,6 +88,7 @@ class SessionPool(object):
                 self.__pool_config.time_zone,
                 use_ssl=self.__pool_config.use_ssl,
                 ca_certs=self.__pool_config.ca_certs,
+                connection_timeout_in_ms=self.__pool_config.connection_timeout_in_ms,
             )
 
         else:
@@ -98,6 +101,7 @@ class SessionPool(object):
                 self.__pool_config.time_zone,
                 use_ssl=self.__pool_config.use_ssl,
                 ca_certs=self.__pool_config.ca_certs,
+                connection_timeout_in_ms=self.__pool_config.connection_timeout_in_ms,
             )
 
         session.open(self.__pool_config.enable_compression)

Reply via email to