This is an automated email from the ASF dual-hosted git repository.

xingtanzjr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 1d1b52cccb8 Optimize result set iteration in Python Client (#11231)
1d1b52cccb8 is described below

commit 1d1b52cccb8945089f8801945ac5c92b22f298dc
Author: Haonan <[email protected]>
AuthorDate: Thu Oct 12 20:45:30 2023 -0500

    Optimize result set iteration in Python Client (#11231)
---
 iotdb-client/client-py/iotdb/utils/Field.py        | 148 +++++-----
 .../client-py/iotdb/utils/IoTDBRpcDataSet.py       | 299 ++++++++++-----------
 .../client-py/iotdb/utils/SessionDataSet.py        |  90 +++----
 iotdb-client/client-py/tests/test_tablet.py        |  18 +-
 4 files changed, 273 insertions(+), 282 deletions(-)

diff --git a/iotdb-client/client-py/iotdb/utils/Field.py b/iotdb-client/client-py/iotdb/utils/Field.py
index 0756b1c49d6..913c7594bf5 100644
--- a/iotdb-client/client-py/iotdb/utils/Field.py
+++ b/iotdb-client/client-py/iotdb/utils/Field.py
@@ -18,20 +18,17 @@
 
 # for package
 from .IoTDBConstants import TSDataType
+import numpy as np
+import pandas as pd
 
 
 class Field(object):
-    def __init__(self, data_type):
+    def __init__(self, data_type, value=None):
         """
         :param data_type: TSDataType
         """
         self.__data_type = data_type
-        self.__bool_value = None
-        self.__int_value = None
-        self.__long_value = None
-        self.__float_value = None
-        self.__double_value = None
-        self.__binary_value = None
+        self.value = value
 
     @staticmethod
     def copy(field):
@@ -59,73 +56,99 @@ class Field(object):
         return self.__data_type
 
     def is_null(self):
-        return self.__data_type is None
+        return self.__data_type is None or self.value is None or self.value is 
pd.NA
 
-    def set_bool_value(self, value):
-        self.__bool_value = value
+    def set_bool_value(self, value: bool):
+        self.value = value
 
     def get_bool_value(self):
         if self.__data_type is None:
             raise Exception("Null Field Exception!")
-        return self.__bool_value
+        if (
+            self.__data_type != TSDataType.BOOLEAN
+            or self.value is None
+            or self.value is pd.NA
+        ):
+            return None
+        return self.value
 
-    def set_int_value(self, value):
-        self.__int_value = value
+    def set_int_value(self, value: int):
+        self.value = value
 
     def get_int_value(self):
         if self.__data_type is None:
             raise Exception("Null Field Exception!")
-        return self.__int_value
+        if (
+            self.__data_type != TSDataType.INT32
+            or self.value is None
+            or self.value is pd.NA
+        ):
+            return None
+        return np.int32(self.value)
 
-    def set_long_value(self, value):
-        self.__long_value = value
+    def set_long_value(self, value: int):
+        self.value = value
 
     def get_long_value(self):
         if self.__data_type is None:
             raise Exception("Null Field Exception!")
-        return self.__long_value
+        if (
+            self.__data_type != TSDataType.INT64
+            or self.value is None
+            or self.value is pd.NA
+        ):
+            return None
+        return np.int64(self.value)
 
-    def set_float_value(self, value):
-        self.__float_value = value
+    def set_float_value(self, value: float):
+        self.value = value
 
     def get_float_value(self):
         if self.__data_type is None:
             raise Exception("Null Field Exception!")
-        return self.__float_value
+        if (
+            self.__data_type != TSDataType.FLOAT
+            or self.value is None
+            or self.value is pd.NA
+        ):
+            return None
+        return np.float32(self.value)
 
-    def set_double_value(self, value):
-        self.__double_value = value
+    def set_double_value(self, value: float):
+        self.value = value
 
     def get_double_value(self):
         if self.__data_type is None:
             raise Exception("Null Field Exception!")
-        return self.__double_value
+        if (
+            self.__data_type != TSDataType.DOUBLE
+            or self.value is None
+            or self.value is pd.NA
+        ):
+            return None
+        return np.float64(self.value)
 
-    def set_binary_value(self, value):
-        self.__binary_value = value
+    def set_binary_value(self, value: bytes):
+        self.value = value
 
     def get_binary_value(self):
         if self.__data_type is None:
             raise Exception("Null Field Exception!")
-        return self.__binary_value
+        if (
+            self.__data_type != TSDataType.TEXT
+            or self.value is None
+            or self.value is pd.NA
+        ):
+            return None
+        return self.value
 
     def get_string_value(self):
-        if self.__data_type is None:
+        if self.__data_type is None or self.value is None or self.value is 
pd.NA:
             return "None"
-        elif self.__data_type == TSDataType.BOOLEAN:
-            return str(self.__bool_value)
-        elif self.__data_type == TSDataType.INT64:
-            return str(self.__long_value)
-        elif self.__data_type == TSDataType.INT32:
-            return str(self.__int_value)
-        elif self.__data_type == TSDataType.FLOAT:
-            return str(self.__float_value)
-        elif self.__data_type == TSDataType.DOUBLE:
-            return str(self.__double_value)
-        elif self.__data_type == TSDataType.TEXT:
-            return self.__binary_value.decode("utf-8")
+        elif self.__data_type == 5:
+            return self.value.decode("utf-8")
         else:
-            raise Exception("unsupported data type 
{}".format(self.__data_type))
+            return str(self.get_object_value(self.__data_type))
 
     def __str__(self):
         return self.get_string_value()
@@ -134,22 +157,19 @@ class Field(object):
         """
         :param data_type: TSDataType
         """
-        if self.__data_type is None:
+        if self.__data_type is None or self.value is None or self.value is 
pd.NA:
             return None
-        elif data_type == TSDataType.BOOLEAN:
-            return self.get_bool_value()
-        elif data_type == TSDataType.INT32:
-            return self.get_int_value()
-        elif data_type == TSDataType.INT64:
-            return self.get_long_value()
-        elif data_type == TSDataType.FLOAT:
-            return self.get_float_value()
-        elif data_type == TSDataType.DOUBLE:
-            return self.get_double_value()
-        elif data_type == TSDataType.TEXT:
-            return self.get_binary_value()
-        else:
-            raise Exception("unsupported data type {}".format(data_type))
+        if data_type == 0:
+            return bool(self.value)
+        elif data_type == 1:
+            return np.int32(self.value)
+        elif data_type == 2:
+            return np.int64(self.value)
+        elif data_type == 3:
+            return np.float32(self.value)
+        elif data_type == 4:
+            return np.float64(self.value)
+        return self.value
 
     @staticmethod
     def get_field(value, data_type):
@@ -157,21 +177,7 @@ class Field(object):
         :param value: field value corresponding to the data type
         :param data_type: TSDataType
         """
-        if value is None:
+        if value is None or value is pd.NA:
             return None
-        field = Field(data_type)
-        if data_type == TSDataType.BOOLEAN:
-            field.set_bool_value(value)
-        elif data_type == TSDataType.INT32:
-            field.set_int_value(value)
-        elif data_type == TSDataType.INT64:
-            field.set_long_value(value)
-        elif data_type == TSDataType.FLOAT:
-            field.set_float_value(value)
-        elif data_type == TSDataType.DOUBLE:
-            field.set_double_value(value)
-        elif data_type == TSDataType.TEXT:
-            field.set_binary_value(value)
-        else:
-            raise Exception("unsupported data type {}".format(data_type))
+        field = Field(data_type, value)
         return field
diff --git a/iotdb-client/client-py/iotdb/utils/IoTDBRpcDataSet.py b/iotdb-client/client-py/iotdb/utils/IoTDBRpcDataSet.py
index 564a0377341..53ef09b8e7a 100644
--- a/iotdb-client/client-py/iotdb/utils/IoTDBRpcDataSet.py
+++ b/iotdb-client/client-py/iotdb/utils/IoTDBRpcDataSet.py
@@ -29,11 +29,14 @@ from iotdb.utils.IoTDBConstants import TSDataType
 logger = logging.getLogger("IoTDB")
 
 
+def _to_bitbuffer(b):
+    return bytes("{:0{}b}".format(int(binascii.hexlify(b), 16), 8 * len(b)), 
"utf-8")
+
+
 class IoTDBRpcDataSet(object):
     TIMESTAMP_STR = "Time"
     # VALUE_IS_NULL = "The value got by %s (column name) is NULL."
     START_INDEX = 2
-    FLAG = 0x80
 
     def __init__(
         self,
@@ -51,62 +54,55 @@ class IoTDBRpcDataSet(object):
     ):
         self.__statement_id = statement_id
         self.__session_id = session_id
-        self.__ignore_timestamp = ignore_timestamp
+        self.ignore_timestamp = ignore_timestamp
         self.__sql = sql
         self.__query_id = query_id
         self.__client = client
         self.__fetch_size = fetch_size
-        self.__column_size = len(column_name_list)
+        self.column_size = len(column_name_list)
         self.__default_time_out = 1000
 
         self.__column_name_list = []
         self.__column_type_list = []
-        self.__column_ordinal_dict = {}
+        self.column_ordinal_dict = {}
         if not ignore_timestamp:
             self.__column_name_list.append(IoTDBRpcDataSet.TIMESTAMP_STR)
             self.__column_type_list.append(TSDataType.INT64)
-            self.__column_ordinal_dict[IoTDBRpcDataSet.TIMESTAMP_STR] = 1
+            self.column_ordinal_dict[IoTDBRpcDataSet.TIMESTAMP_STR] = 1
 
         if column_name_index is not None:
-            self.__column_type_deduplicated_list = [
+            self.column_type_deduplicated_list = [
                 None for _ in range(len(column_name_index))
             ]
-            for i in range(len(column_name_list)):
+            for i in range(self.column_size):
                 name = column_name_list[i]
                 self.__column_name_list.append(name)
                 self.__column_type_list.append(TSDataType[column_type_list[i]])
-                if name not in self.__column_ordinal_dict:
+                if name not in self.column_ordinal_dict:
                     index = column_name_index[name]
-                    self.__column_ordinal_dict[name] = (
-                        index + IoTDBRpcDataSet.START_INDEX
-                    )
-                    self.__column_type_deduplicated_list[index] = TSDataType[
+                    self.column_ordinal_dict[name] = index + 
IoTDBRpcDataSet.START_INDEX
+                    self.column_type_deduplicated_list[index] = TSDataType[
                         column_type_list[i]
                     ]
         else:
             index = IoTDBRpcDataSet.START_INDEX
-            self.__column_type_deduplicated_list = []
+            self.column_type_deduplicated_list = []
             for i in range(len(column_name_list)):
                 name = column_name_list[i]
                 self.__column_name_list.append(name)
                 self.__column_type_list.append(TSDataType[column_type_list[i]])
-                if name not in self.__column_ordinal_dict:
-                    self.__column_ordinal_dict[name] = index
+                if name not in self.column_ordinal_dict:
+                    self.column_ordinal_dict[name] = index
                     index += 1
-                    self.__column_type_deduplicated_list.append(
+                    self.column_type_deduplicated_list.append(
                         TSDataType[column_type_list[i]]
                     )
-
-        self.__time_bytes = bytes(0)
-        self.__current_bitmap = [
-            bytes(0) for _ in range(len(self.__column_type_deduplicated_list))
-        ]
-        self.__value = [None for _ in 
range(len(self.__column_type_deduplicated_list))]
         self.__query_data_set = query_data_set
         self.__is_closed = False
         self.__empty_resultSet = False
-        self.__has_cached_record = False
         self.__rows_index = 0
+        self.has_cached_data_frame = False
+        self.data_frame = None
 
     def close(self):
         if self.__is_closed:
@@ -132,23 +128,120 @@ class IoTDBRpcDataSet(object):
             self.__client = None
 
     def next(self):
-        if self.has_cached_result():
-            self.construct_one_row()
+        if not self.has_cached_data_frame:
+            self.construct_one_data_frame()
+        if self.has_cached_data_frame:
             return True
         if self.__empty_resultSet:
             return False
         if self.fetch_results():
-            self.construct_one_row()
+            self.construct_one_data_frame()
             return True
         return False
 
-    def has_cached_result(self):
-        return (self.__query_data_set is not None) and (
-            len(self.__query_data_set.time) != 0
+    def construct_one_data_frame(self):
+        if (
+            self.has_cached_data_frame
+            or self.__query_data_set is None
+            or len(self.__query_data_set.time) == 0
+        ):
+            return
+        result = {}
+        time_array = np.frombuffer(
+            self.__query_data_set.time, np.dtype(np.longlong).newbyteorder(">")
         )
+        if time_array.dtype.byteorder == ">":
+            time_array = time_array.byteswap().newbyteorder("<")
+        result[0] = time_array
+        total_length = len(time_array)
+        for i in range(self.column_size):
+            if self.ignore_timestamp is True:
+                column_name = self.__column_name_list[i]
+            else:
+                column_name = self.__column_name_list[i + 1]
+
+            location = (
+                self.column_ordinal_dict[column_name] - 
IoTDBRpcDataSet.START_INDEX
+            )
+            if location < 0:
+                continue
+            data_type = self.column_type_deduplicated_list[location]
+            value_buffer = self.__query_data_set.valueList[location]
+            value_buffer_len = len(value_buffer)
+            if data_type == 4:
+                data_array = np.frombuffer(
+                    value_buffer, np.dtype(np.double).newbyteorder(">")
+                )
+            elif data_type == 3:
+                data_array = np.frombuffer(
+                    value_buffer, np.dtype(np.float32).newbyteorder(">")
+                )
+            elif data_type == 0:
+                data_array = np.frombuffer(value_buffer, np.dtype("?"))
+            elif data_type == 1:
+                data_array = np.frombuffer(
+                    value_buffer, np.dtype(np.int32).newbyteorder(">")
+                )
+            elif data_type == 2:
+                data_array = np.frombuffer(
+                    value_buffer, np.dtype(np.int64).newbyteorder(">")
+                )
+            elif data_type == 5:
+                j = 0
+                offset = 0
+                data_array = []
+                while offset < value_buffer_len:
+                    length = int.from_bytes(
+                        value_buffer[offset : offset + 4],
+                        byteorder="big",
+                        signed=False,
+                    )
+                    offset += 4
+                    value = bytes(value_buffer[offset : offset + length])
+                    data_array.append(value)
+                    j += 1
+                    offset += length
+                data_array = np.array(data_array, dtype=object)
+            else:
+                raise RuntimeError("unsupported data type 
{}.".format(data_type))
+            if data_array.dtype.byteorder == ">":
+                data_array = data_array.byteswap().newbyteorder("<")
+            # self.__query_data_set.valueList[location] = None
+            if len(data_array) < total_length:
+                # INT32 or INT64 or boolean
+                if data_type == 0 or data_type == 1 or data_type == 2:
+                    tmp_array = np.full(total_length, np.nan, np.float32)
+                else:
+                    tmp_array = np.full(total_length, None, dtype=object)
+
+                bitmap_buffer = self.__query_data_set.bitmapList[location]
+                buffer = _to_bitbuffer(bitmap_buffer)
+                bit_mask = (np.frombuffer(buffer, "u1") - 
ord("0")).astype(bool)
+                if len(bit_mask) != total_length:
+                    bit_mask = bit_mask[:total_length]
+                tmp_array[bit_mask] = data_array
+
+                if data_type == 1:
+                    tmp_array = pd.Series(tmp_array, dtype="Int32")
+                elif data_type == 2:
+                    tmp_array = pd.Series(tmp_array, dtype="Int64")
+                elif data_type == 0:
+                    tmp_array = pd.Series(tmp_array, dtype="boolean")
+                data_array = tmp_array
+
+            result[i + 1] = data_array
+        self.__query_data_set = None
+        self.data_frame = pd.DataFrame(result, dtype=object)
+        if not self.data_frame.empty:
+            self.has_cached_data_frame = True
+
+    def has_cached_result(self):
+        return self.has_cached_data_frame
 
     def _has_next_result_set(self):
-        if self.has_cached_result():
+        if (self.__query_data_set is not None) and (
+            len(self.__query_data_set.time) != 0
+        ):
             return True
         if self.__empty_resultSet:
             return False
@@ -156,11 +249,6 @@ class IoTDBRpcDataSet(object):
             return True
         return False
 
-    def _to_bitbuffer(self, b):
-        return bytes(
-            "{:0{}b}".format(int(binascii.hexlify(b), 16), 8 * len(b)), "utf-8"
-        )
-
     def resultset_to_pandas(self):
         result = {}
         for column_name in self.__column_name_list:
@@ -171,51 +259,45 @@ class IoTDBRpcDataSet(object):
             )
             if time_array.dtype.byteorder == ">":
                 time_array = time_array.byteswap().newbyteorder("<")
-            if (
-                self.get_ignore_timestamp() is None
-                or self.get_ignore_timestamp() is False
-            ):
+            if self.ignore_timestamp is None or self.ignore_timestamp is False:
                 result[IoTDBRpcDataSet.TIMESTAMP_STR].append(time_array)
 
             self.__query_data_set.time = []
             total_length = len(time_array)
 
             for i in range(len(self.__query_data_set.bitmapList)):
-                if self.get_ignore_timestamp() is True:
-                    column_name = self.get_column_names()[i]
+                if self.ignore_timestamp is True:
+                    column_name = self.__column_name_list[i]
                 else:
-                    column_name = self.get_column_names()[i + 1]
+                    column_name = self.__column_name_list[i + 1]
 
                 location = (
-                    self.__column_ordinal_dict[column_name]
-                    - IoTDBRpcDataSet.START_INDEX
+                    self.column_ordinal_dict[column_name] - 
IoTDBRpcDataSet.START_INDEX
                 )
                 if location < 0:
                     continue
-                data_type = self.__column_type_deduplicated_list[location]
+                data_type = self.column_type_deduplicated_list[location]
                 value_buffer = self.__query_data_set.valueList[location]
                 value_buffer_len = len(value_buffer)
-
-                data_array = None
-                if data_type == TSDataType.DOUBLE:
+                if data_type == 4:
                     data_array = np.frombuffer(
                         value_buffer, np.dtype(np.double).newbyteorder(">")
                     )
-                elif data_type == TSDataType.FLOAT:
+                elif data_type == 3:
                     data_array = np.frombuffer(
                         value_buffer, np.dtype(np.float32).newbyteorder(">")
                     )
-                elif data_type == TSDataType.BOOLEAN:
+                elif data_type == 0:
                     data_array = np.frombuffer(value_buffer, np.dtype("?"))
-                elif data_type == TSDataType.INT32:
+                elif data_type == 1:
                     data_array = np.frombuffer(
                         value_buffer, np.dtype(np.int32).newbyteorder(">")
                     )
-                elif data_type == TSDataType.INT64:
+                elif data_type == 2:
                     data_array = np.frombuffer(
                         value_buffer, np.dtype(np.int64).newbyteorder(">")
                     )
-                elif data_type == TSDataType.TEXT:
+                elif data_type == 5:
                     j = 0
                     offset = 0
                     data_array = []
@@ -226,7 +308,7 @@ class IoTDBRpcDataSet(object):
                             signed=False,
                         )
                         offset += 4
-                        value_bytes = value_buffer[offset : offset + length]
+                        value_bytes = bytes(value_buffer[offset : offset + 
length])
                         value = value_bytes.decode("utf-8")
                         data_array.append(value)
                         j += 1
@@ -237,31 +319,29 @@ class IoTDBRpcDataSet(object):
                 if data_array.dtype.byteorder == ">":
                     data_array = data_array.byteswap().newbyteorder("<")
                 self.__query_data_set.valueList[location] = None
-
+                tmp_array = []
                 if len(data_array) < total_length:
-                    if data_type == TSDataType.INT32 or data_type == 
TSDataType.INT64:
+                    if data_type == 1 or data_type == 2:
                         tmp_array = np.full(total_length, np.nan, np.float32)
-                    elif (
-                        data_type == TSDataType.FLOAT or data_type == 
TSDataType.DOUBLE
-                    ):
+                    elif data_type == 3 or data_type == 4:
                         tmp_array = np.full(total_length, np.nan, 
data_array.dtype)
-                    elif data_type == TSDataType.BOOLEAN:
+                    elif data_type == 0:
                         tmp_array = np.full(total_length, np.nan, np.float32)
-                    elif data_type == TSDataType.TEXT:
+                    elif data_type == 5:
                         tmp_array = np.full(total_length, None, 
dtype=data_array.dtype)
 
                     bitmap_buffer = self.__query_data_set.bitmapList[location]
-                    buffer = self._to_bitbuffer(bitmap_buffer)
+                    buffer = _to_bitbuffer(bitmap_buffer)
                     bit_mask = (np.frombuffer(buffer, "u1") - 
ord("0")).astype(bool)
                     if len(bit_mask) != total_length:
                         bit_mask = bit_mask[:total_length]
                     tmp_array[bit_mask] = data_array
 
-                    if data_type == TSDataType.INT32:
+                    if data_type == 1:
                         tmp_array = pd.Series(tmp_array).astype("Int32")
-                    elif data_type == TSDataType.INT64:
+                    elif data_type == 2:
                         tmp_array = pd.Series(tmp_array).astype("Int64")
-                    elif data_type == TSDataType.BOOLEAN:
+                    elif data_type == 0:
                         tmp_array = pd.Series(tmp_array).astype("boolean")
 
                     data_array = tmp_array
@@ -283,48 +363,6 @@ class IoTDBRpcDataSet(object):
         df = pd.DataFrame(result)
         return df
 
-    def construct_one_row(self):
-        # simulating buffer, read 8 bytes from data set and discard first 8 
bytes which have been read.
-        self.__time_bytes = self.__query_data_set.time[:8]
-        self.__query_data_set.time = self.__query_data_set.time[8:]
-        for i in range(len(self.__query_data_set.bitmapList)):
-            bitmap_buffer = self.__query_data_set.bitmapList[i]
-
-            # another 8 new rows, should move the bitmap buffer position to 
next byte
-            if self.__rows_index % 8 == 0:
-                self.__current_bitmap[i] = bitmap_buffer[0]
-                self.__query_data_set.bitmapList[i] = bitmap_buffer[1:]
-            if not self.is_null(i, self.__rows_index):
-                value_buffer = self.__query_data_set.valueList[i]
-                data_type = self.__column_type_deduplicated_list[i]
-
-                # simulating buffer
-                if data_type == TSDataType.BOOLEAN:
-                    self.__value[i] = value_buffer[:1]
-                    self.__query_data_set.valueList[i] = value_buffer[1:]
-                elif data_type == TSDataType.INT32:
-                    self.__value[i] = value_buffer[:4]
-                    self.__query_data_set.valueList[i] = value_buffer[4:]
-                elif data_type == TSDataType.INT64:
-                    self.__value[i] = value_buffer[:8]
-                    self.__query_data_set.valueList[i] = value_buffer[8:]
-                elif data_type == TSDataType.FLOAT:
-                    self.__value[i] = value_buffer[:4]
-                    self.__query_data_set.valueList[i] = value_buffer[4:]
-                elif data_type == TSDataType.DOUBLE:
-                    self.__value[i] = value_buffer[:8]
-                    self.__query_data_set.valueList[i] = value_buffer[8:]
-                elif data_type == TSDataType.TEXT:
-                    length = int.from_bytes(
-                        value_buffer[:4], byteorder="big", signed=False
-                    )
-                    self.__value[i] = value_buffer[4 : 4 + length]
-                    self.__query_data_set.valueList[i] = value_buffer[4 + 
length :]
-                else:
-                    raise RuntimeError("unsupported data type 
{}.".format(data_type))
-        self.__rows_index += 1
-        self.__has_cached_record = True
-
     def fetch_results(self):
         self.__rows_index = 0
         request = TSFetchResultsReq(
@@ -347,36 +385,12 @@ class IoTDBRpcDataSet(object):
                 "Cannot fetch result from server, because of network 
connection: ", e
             )
 
-    def is_null(self, index, row_num):
-        bitmap = self.__current_bitmap[index]
-        shift = row_num % 8
-        return ((IoTDBRpcDataSet.FLAG >> shift) & (bitmap & 0xFF)) == 0
-
-    def is_null_by_index(self, column_index):
-        index = (
-            
self.__column_ordinal_dict[self.find_column_name_by_index(column_index)]
-            - IoTDBRpcDataSet.START_INDEX
-        )
-        # time column will never be None
-        if index < 0:
-            return True
-        return self.is_null(index, self.__rows_index - 1)
-
-    def is_null_by_name(self, column_name):
-        index = self.__column_ordinal_dict[column_name] - 
IoTDBRpcDataSet.START_INDEX
-        # time column will never be None
-        if index < 0:
-            return True
-        return self.is_null(index, self.__rows_index - 1)
-
     def find_column_name_by_index(self, column_index):
         if column_index <= 0:
             raise Exception("Column index should start from 1")
         if column_index > len(self.__column_name_list):
             raise Exception(
-                "column index {} out of range {}".format(
-                    column_index, self.__column_size
-                )
+                "column index {} out of range {}".format(column_index, 
self.column_size)
             )
         return self.__column_name_list[column_index - 1]
 
@@ -391,24 +405,3 @@ class IoTDBRpcDataSet(object):
 
     def get_column_types(self):
         return self.__column_type_list
-
-    def get_column_size(self):
-        return self.__column_size
-
-    def get_ignore_timestamp(self):
-        return self.__ignore_timestamp
-
-    def get_column_ordinal_dict(self):
-        return self.__column_ordinal_dict
-
-    def get_column_type_deduplicated_list(self):
-        return self.__column_type_deduplicated_list
-
-    def get_values(self):
-        return self.__value
-
-    def get_time_bytes(self):
-        return self.__time_bytes
-
-    def get_has_cached_record(self):
-        return self.__has_cached_record
diff --git a/iotdb-client/client-py/iotdb/utils/SessionDataSet.py b/iotdb-client/client-py/iotdb/utils/SessionDataSet.py
index 02eef027dfd..a7ce7dfc40c 100644
--- a/iotdb-client/client-py/iotdb/utils/SessionDataSet.py
+++ b/iotdb-client/client-py/iotdb/utils/SessionDataSet.py
@@ -16,7 +16,6 @@
 # under the License.
 #
 import logging
-import struct
 
 from iotdb.utils.Field import Field
 
@@ -55,8 +54,26 @@ class SessionDataSet(object):
             statement_id,
             session_id,
             query_data_set,
-            1024,
+            5000,
         )
+        self.column_size = self.iotdb_rpc_data_set.column_size
+        self.is_ignore_timestamp = self.iotdb_rpc_data_set.ignore_timestamp
+        self.column_names = tuple(self.iotdb_rpc_data_set.get_column_names())
+        self.column_ordinal_dict = self.iotdb_rpc_data_set.column_ordinal_dict
+        self.column_type_deduplicated_list = tuple(
+            self.iotdb_rpc_data_set.column_type_deduplicated_list
+        )
+        if self.is_ignore_timestamp:
+            self.__field_list = [
+                Field(data_type)
+                for data_type in self.iotdb_rpc_data_set.get_column_types()
+            ]
+        else:
+            self.__field_list = [
+                Field(data_type)
+                for data_type in self.iotdb_rpc_data_set.get_column_types()[1:]
+            ]
+        self.row_index = 0
 
     def __enter__(self):
         return self
@@ -80,58 +97,29 @@ class SessionDataSet(object):
         return self.iotdb_rpc_data_set.next()
 
     def next(self):
-        if not self.iotdb_rpc_data_set.get_has_cached_record():
+        if not self.iotdb_rpc_data_set.has_cached_data_frame:
             if not self.has_next():
                 return None
-        self.iotdb_rpc_data_set.has_cached_record = False
-        return self.construct_row_record_from_value_array()
-
-    def construct_row_record_from_value_array(self):
-        out_fields = []
-        for i in range(self.iotdb_rpc_data_set.get_column_size()):
-            index = i + 1
-            data_set_column_index = i + IoTDBRpcDataSet.START_INDEX
-            if self.iotdb_rpc_data_set.get_ignore_timestamp():
-                index -= 1
-                data_set_column_index -= 1
-            column_name = self.iotdb_rpc_data_set.get_column_names()[index]
-            location = (
-                self.iotdb_rpc_data_set.get_column_ordinal_dict()[column_name]
-                - IoTDBRpcDataSet.START_INDEX
-            )
-
-            if not 
self.iotdb_rpc_data_set.is_null_by_index(data_set_column_index):
-                value_bytes = self.iotdb_rpc_data_set.get_values()[location]
-                data_type = 
self.iotdb_rpc_data_set.get_column_type_deduplicated_list()[
-                    location
-                ]
-                field = Field(data_type)
-                if data_type == TSDataType.BOOLEAN:
-                    value = struct.unpack(">?", value_bytes)[0]
-                    field.set_bool_value(value)
-                elif data_type == TSDataType.INT32:
-                    value = struct.unpack(">i", value_bytes)[0]
-                    field.set_int_value(value)
-                elif data_type == TSDataType.INT64:
-                    value = struct.unpack(">q", value_bytes)[0]
-                    field.set_long_value(value)
-                elif data_type == TSDataType.FLOAT:
-                    value = struct.unpack(">f", value_bytes)[0]
-                    field.set_float_value(value)
-                elif data_type == TSDataType.DOUBLE:
-                    value = struct.unpack(">d", value_bytes)[0]
-                    field.set_double_value(value)
-                elif data_type == TSDataType.TEXT:
-                    field.set_binary_value(value_bytes)
-                else:
-                    raise RuntimeError("unsupported data type 
{}.".format(data_type))
-            else:
-                field = Field(None)
-            out_fields.append(field)
-
-        return RowRecord(
-            struct.unpack(">q", self.iotdb_rpc_data_set.get_time_bytes())[0], 
out_fields
+        return self.construct_row_record_from_data_frame()
+
+    def construct_row_record_from_data_frame(self):
+        df = self.iotdb_rpc_data_set.data_frame
+        row = df.iloc[self.row_index].to_list()
+        row_values = row[1:]
+        for i, value in enumerate(row_values):
+            self.__field_list[i].value = value
+
+        row_record = RowRecord(
+            row[0],
+            self.__field_list,
         )
+        self.row_index += 1
+        if self.row_index == len(df):
+            self.row_index = 0
+            self.iotdb_rpc_data_set.has_cached_data_frame = False
+            self.iotdb_rpc_data_set.data_frame = None
+
+        return row_record
 
     def close_operation_handle(self):
         self.iotdb_rpc_data_set.close()
diff --git a/iotdb-client/client-py/tests/test_tablet.py b/iotdb-client/client-py/tests/test_tablet.py
index 3e2d53543e4..8a035318d9f 100644
--- a/iotdb-client/client-py/tests/test_tablet.py
+++ b/iotdb-client/client-py/tests/test_tablet.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
+import numpy as np
 import pandas as pd
 from pandas.testing import assert_frame_equal
 
@@ -55,14 +55,16 @@ def test_tablet_insertion():
         columns = []
         for measurement in measurements_:
             columns.append("root.sg_test_01.d_01." + measurement)
-        df_input = pd.DataFrame(values_, None, columns)
-        df_input.insert(0, "Time", timestamps_)
+        df_input = pd.DataFrame(values_, columns=columns, dtype=object)
+        df_input.insert(0, "Time", np.array(timestamps_))
 
         session_data_set = session.execute_query_statement(
             "select s_01, s_02, s_03, s_04, s_05, s_06 from 
root.sg_test_01.d_01"
         )
         df_output = session_data_set.todf()
-        df_output = df_output[df_input.columns.tolist()]
+        df_output = df_output[df_input.columns.tolist()].replace(
+            {pd.NA: None, np.nan: None}
+        )
 
         session.close()
     assert_frame_equal(df_input, df_output, False)
@@ -98,14 +100,16 @@ def test_nullable_tablet_insertion():
         columns = []
         for measurement in measurements_:
             columns.append("root.sg_test_01.d_01." + measurement)
-        df_input = pd.DataFrame(values_, None, columns)
-        df_input.insert(0, "Time", timestamps_)
+        df_input = pd.DataFrame(values_, columns=columns, dtype=object)
+        df_input.insert(0, "Time", np.array(timestamps_))
 
         session_data_set = session.execute_query_statement(
             "select s_01, s_02, s_03, s_04, s_05, s_06 from 
root.sg_test_01.d_01"
         )
         df_output = session_data_set.todf()
         df_output = df_output[df_input.columns.tolist()]
-
+        df_output = df_output[df_input.columns.tolist()].replace(
+            {pd.NA: None, np.nan: None}
+        )
         session.close()
     assert_frame_equal(df_input, df_output, False)


Reply via email to