This is an automated email from the ASF dual-hosted git repository.

haonan pushed a commit to branch fix_python_reconnect
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 172196e3e4e70e683d321e7baefeeb37a68cda68
Author: HTHou <[email protected]>
AuthorDate: Fri Jun 9 11:40:46 2023 +0800

    Fix reconnect logic of Python client
---
 iotdb-client/client-py/iotdb/Session.py            | 20 +++++---
 .../client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py | 56 +++++++++++-----------
 2 files changed, 41 insertions(+), 35 deletions(-)

diff --git a/iotdb-client/client-py/iotdb/Session.py b/iotdb-client/client-py/iotdb/Session.py
index c9e5f3902d1..310db405679 100644
--- a/iotdb-client/client-py/iotdb/Session.py
+++ b/iotdb-client/client-py/iotdb/Session.py
@@ -142,11 +142,16 @@ class Session(object):
                     self.__default_connection = self.init_connection(
                         self.__default_endpoint
                     )
-                except Exception:
+                except Exception as e:
                     if not self.reconnect():
-                        raise IoTDBConnectionException(
-                            "Cluster has no nodes to connect"
-                        ) from None
+                        if str(e).startswith("Could not connect to any of"):
+                            error_msg = (
+                                "Cluster has no nodes to connect because: "
+                                + self.connection_error_msg()
+                            )
+                        else:
+                            error_msg = str(e)
+                        raise IoTDBConnectionException(error_msg) from None
                 break
         self.__client = self.__default_connection.client
         self.__session_id = self.__default_connection.session_id
@@ -1739,9 +1744,10 @@ class Session(object):
                 and self.__default_connection.transport is not None
             ):
                 self.__default_connection.transport.close()
-            curr_host_index = random.randint(0, len(self.__hosts))
+            curr_host_index = random.randint(0, len(self.__hosts) - 1)
             try_host_num = 0
-            for j in range(curr_host_index, len(self.__hosts)):
+            j = curr_host_index
+            while j < len(self.__hosts):
                 if try_host_num == len(self.__hosts):
                     break
                 self.__default_endpoint = TEndPoint(self.__hosts[j], 
self.__ports[j])
@@ -1762,6 +1768,8 @@ class Session(object):
                         }
                 except IoTDBConnectionException:
                     pass
+                    j += 1
+                    continue
                 break
             if connected:
                 break
diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py
index 8d966febd69..008a314e683 100644
--- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py
+++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py
@@ -66,9 +66,7 @@ class IoTDBSQLCompiler(SQLCompiler):
 
         kwargs["within_columns_clause"] = False
 
-        compile_state = select_stmt._compile_state_factory(
-            select_stmt, self, **kwargs
-        )
+        compile_state = select_stmt._compile_state_factory(select_stmt, self, 
**kwargs)
         select_stmt = compile_state.statement
 
         toplevel = not self.stack
@@ -101,9 +99,9 @@ class IoTDBSQLCompiler(SQLCompiler):
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
         populate_result_map = need_column_expressions = (
-                toplevel
-                or entry.get("need_result_map_for_compound", False)
-                or entry.get("need_result_map_for_nested", False)
+            toplevel
+            or entry.get("need_result_map_for_compound", False)
+            or entry.get("need_result_map_for_nested", False)
         )
 
         # indicates there is a CompoundSelect in play and we are not the
@@ -181,22 +179,22 @@ class IoTDBSQLCompiler(SQLCompiler):
                     [
                         name
                         for (
-                        key,
-                        proxy_name,
-                        fallback_label_name,
-                        name,
-                        repeated,
-                    ) in compile_state.columns_plus_names
+                            key,
+                            proxy_name,
+                            fallback_label_name,
+                            name,
+                            repeated,
+                        ) in compile_state.columns_plus_names
                     ],
                     [
                         name
                         for (
-                        key,
-                        proxy_name,
-                        fallback_label_name,
-                        name,
-                        repeated,
-                    ) in compile_state_wraps_for.columns_plus_names
+                            key,
+                            proxy_name,
+                            fallback_label_name,
+                            name,
+                            repeated,
+                        ) in compile_state_wraps_for.columns_plus_names
                     ],
                 )
             )
@@ -236,18 +234,18 @@ class IoTDBSQLCompiler(SQLCompiler):
         inner_columns = list(
             filter(
                 lambda x: "Time"
-                          not in x.replace(self.preparer.initial_quote, 
"").split(),
+                not in x.replace(self.preparer.initial_quote, "").split(),
                 inner_columns,
             )
         )
 
         if inner_columns and time_column_index:
             inner_columns[-1] = (
-                    inner_columns[-1]
-                    + " \n FROM Time Index "
-                    + " ".join(time_column_index)
-                    + " \n FROM Time Name "
-                    + " ".join(time_column_names)
+                inner_columns[-1]
+                + " \n FROM Time Index "
+                + " ".join(time_column_index)
+                + " \n FROM Time Name "
+                + " ".join(time_column_names)
             )
 
         text = self._compose_select_body(
@@ -274,11 +272,11 @@ class IoTDBSQLCompiler(SQLCompiler):
         if self.ctes and (not is_embedded_select or toplevel):
             nesting_level = len(self.stack) if not toplevel else None
             text = (
-                    self._render_cte_clause(
-                        nesting_level=nesting_level,
-                        visiting_cte=kwargs.get("visiting_cte"),
-                    )
-                    + text
+                self._render_cte_clause(
+                    nesting_level=nesting_level,
+                    visiting_cte=kwargs.get("visiting_cte"),
+                )
+                + text
             )
 
         if select_stmt._suffixes:

Reply via email to