This is an automated email from the ASF dual-hosted git repository.
xingtanzjr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 1d1b52cccb8 Optimize result set iteration in Python Client (#11231)
1d1b52cccb8 is described below
commit 1d1b52cccb8945089f8801945ac5c92b22f298dc
Author: Haonan <[email protected]>
AuthorDate: Thu Oct 12 20:45:30 2023 -0500
Optimize result set iteration in Python Client (#11231)
---
iotdb-client/client-py/iotdb/utils/Field.py | 148 +++++-----
.../client-py/iotdb/utils/IoTDBRpcDataSet.py | 299 ++++++++++-----------
.../client-py/iotdb/utils/SessionDataSet.py | 90 +++----
iotdb-client/client-py/tests/test_tablet.py | 18 +-
4 files changed, 273 insertions(+), 282 deletions(-)
diff --git a/iotdb-client/client-py/iotdb/utils/Field.py
b/iotdb-client/client-py/iotdb/utils/Field.py
index 0756b1c49d6..913c7594bf5 100644
--- a/iotdb-client/client-py/iotdb/utils/Field.py
+++ b/iotdb-client/client-py/iotdb/utils/Field.py
@@ -18,20 +18,17 @@
# for package
from .IoTDBConstants import TSDataType
+import numpy as np
+import pandas as pd
class Field(object):
- def __init__(self, data_type):
+ def __init__(self, data_type, value=None):
"""
:param data_type: TSDataType
"""
self.__data_type = data_type
- self.__bool_value = None
- self.__int_value = None
- self.__long_value = None
- self.__float_value = None
- self.__double_value = None
- self.__binary_value = None
+ self.value = value
@staticmethod
def copy(field):
@@ -59,73 +56,99 @@ class Field(object):
return self.__data_type
def is_null(self):
- return self.__data_type is None
+ return self.__data_type is None or self.value is pd.NA
- def set_bool_value(self, value):
- self.__bool_value = value
+ def set_bool_value(self, value: bool):
+ self.value = value
def get_bool_value(self):
if self.__data_type is None:
raise Exception("Null Field Exception!")
- return self.__bool_value
+ if (
+ self.__data_type != TSDataType.BOOLEAN
+ or self.value is None
+ or self.value is pd.NA
+ ):
+ return None
+ return self.value
- def set_int_value(self, value):
- self.__int_value = value
+ def set_int_value(self, value: int):
+ self.value = value
def get_int_value(self):
if self.__data_type is None:
raise Exception("Null Field Exception!")
- return self.__int_value
+ if (
+ self.__data_type != TSDataType.INT32
+ or self.value is None
+ or self.value is pd.NA
+ ):
+ return None
+ return np.int32(self.value)
- def set_long_value(self, value):
- self.__long_value = value
+ def set_long_value(self, value: int):
+ self.value = value
def get_long_value(self):
if self.__data_type is None:
raise Exception("Null Field Exception!")
- return self.__long_value
+ if (
+ self.__data_type != TSDataType.INT64
+ or self.value is None
+ or self.value is pd.NA
+ ):
+ return None
+ return np.int64(self.value)
- def set_float_value(self, value):
- self.__float_value = value
+ def set_float_value(self, value: float):
+ self.value = value
def get_float_value(self):
if self.__data_type is None:
raise Exception("Null Field Exception!")
- return self.__float_value
+ if (
+ self.__data_type != TSDataType.FLOAT
+ or self.value is None
+ or self.value is pd.NA
+ ):
+ return None
+ return np.float32(self.value)
- def set_double_value(self, value):
- self.__double_value = value
+ def set_double_value(self, value: float):
+ self.value = value
def get_double_value(self):
if self.__data_type is None:
raise Exception("Null Field Exception!")
- return self.__double_value
+ if (
+ self.__data_type != TSDataType.DOUBLE
+ or self.value is None
+ or self.value is pd.NA
+ ):
+ return None
+ return np.float64(self.value)
- def set_binary_value(self, value):
- self.__binary_value = value
+ def set_binary_value(self, value: bytes):
+ self.value = value
def get_binary_value(self):
if self.__data_type is None:
raise Exception("Null Field Exception!")
- return self.__binary_value
+ if (
+ self.__data_type != TSDataType.TEXT
+ or self.value is None
+ or self.value is pd.NA
+ ):
+ return None
+ return self.value
def get_string_value(self):
- if self.__data_type is None:
+ if self.__data_type is None or self.value is pd.NA:
return "None"
- elif self.__data_type == TSDataType.BOOLEAN:
- return str(self.__bool_value)
- elif self.__data_type == TSDataType.INT64:
- return str(self.__long_value)
- elif self.__data_type == TSDataType.INT32:
- return str(self.__int_value)
- elif self.__data_type == TSDataType.FLOAT:
- return str(self.__float_value)
- elif self.__data_type == TSDataType.DOUBLE:
- return str(self.__double_value)
- elif self.__data_type == TSDataType.TEXT:
- return self.__binary_value.decode("utf-8")
+ elif self.__data_type == 5:
+ return self.value.decode("utf-8")
else:
- raise Exception("unsupported data type {}".format(self.__data_type))
+ return str(self.get_object_value(self.__data_type))
def __str__(self):
return self.get_string_value()
@@ -134,22 +157,19 @@ class Field(object):
"""
:param data_type: TSDataType
"""
- if self.__data_type is None:
+ if self.__data_type is None or self.value is pd.NA:
return None
- elif data_type == TSDataType.BOOLEAN:
- return self.get_bool_value()
- elif data_type == TSDataType.INT32:
- return self.get_int_value()
- elif data_type == TSDataType.INT64:
- return self.get_long_value()
- elif data_type == TSDataType.FLOAT:
- return self.get_float_value()
- elif data_type == TSDataType.DOUBLE:
- return self.get_double_value()
- elif data_type == TSDataType.TEXT:
- return self.get_binary_value()
- else:
- raise Exception("unsupported data type {}".format(data_type))
+ if data_type == 0:
+ return bool(self.value)
+ elif data_type == 1:
+ return np.int32(self.value)
+ elif data_type == 2:
+ return np.int64(self.value)
+ elif data_type == 3:
+ return np.float32(self.value)
+ elif data_type == 4:
+ return np.float64(self.value)
+ return self.value
@staticmethod
def get_field(value, data_type):
@@ -157,21 +177,7 @@ class Field(object):
:param value: field value corresponding to the data type
:param data_type: TSDataType
"""
- if value is None:
+ if value is None or value is pd.NA:
return None
- field = Field(data_type)
- if data_type == TSDataType.BOOLEAN:
- field.set_bool_value(value)
- elif data_type == TSDataType.INT32:
- field.set_int_value(value)
- elif data_type == TSDataType.INT64:
- field.set_long_value(value)
- elif data_type == TSDataType.FLOAT:
- field.set_float_value(value)
- elif data_type == TSDataType.DOUBLE:
- field.set_double_value(value)
- elif data_type == TSDataType.TEXT:
- field.set_binary_value(value)
- else:
- raise Exception("unsupported data type {}".format(data_type))
+ field = Field(data_type, value)
return field
diff --git a/iotdb-client/client-py/iotdb/utils/IoTDBRpcDataSet.py
b/iotdb-client/client-py/iotdb/utils/IoTDBRpcDataSet.py
index 564a0377341..53ef09b8e7a 100644
--- a/iotdb-client/client-py/iotdb/utils/IoTDBRpcDataSet.py
+++ b/iotdb-client/client-py/iotdb/utils/IoTDBRpcDataSet.py
@@ -29,11 +29,14 @@ from iotdb.utils.IoTDBConstants import TSDataType
logger = logging.getLogger("IoTDB")
+def _to_bitbuffer(b):
+ return bytes("{:0{}b}".format(int(binascii.hexlify(b), 16), 8 * len(b)), "utf-8")
+
+
class IoTDBRpcDataSet(object):
TIMESTAMP_STR = "Time"
# VALUE_IS_NULL = "The value got by %s (column name) is NULL."
START_INDEX = 2
- FLAG = 0x80
def __init__(
self,
@@ -51,62 +54,55 @@ class IoTDBRpcDataSet(object):
):
self.__statement_id = statement_id
self.__session_id = session_id
- self.__ignore_timestamp = ignore_timestamp
+ self.ignore_timestamp = ignore_timestamp
self.__sql = sql
self.__query_id = query_id
self.__client = client
self.__fetch_size = fetch_size
- self.__column_size = len(column_name_list)
+ self.column_size = len(column_name_list)
self.__default_time_out = 1000
self.__column_name_list = []
self.__column_type_list = []
- self.__column_ordinal_dict = {}
+ self.column_ordinal_dict = {}
if not ignore_timestamp:
self.__column_name_list.append(IoTDBRpcDataSet.TIMESTAMP_STR)
self.__column_type_list.append(TSDataType.INT64)
- self.__column_ordinal_dict[IoTDBRpcDataSet.TIMESTAMP_STR] = 1
+ self.column_ordinal_dict[IoTDBRpcDataSet.TIMESTAMP_STR] = 1
if column_name_index is not None:
- self.__column_type_deduplicated_list = [
+ self.column_type_deduplicated_list = [
None for _ in range(len(column_name_index))
]
- for i in range(len(column_name_list)):
+ for i in range(self.column_size):
name = column_name_list[i]
self.__column_name_list.append(name)
self.__column_type_list.append(TSDataType[column_type_list[i]])
- if name not in self.__column_ordinal_dict:
+ if name not in self.column_ordinal_dict:
index = column_name_index[name]
- self.__column_ordinal_dict[name] = (
- index + IoTDBRpcDataSet.START_INDEX
- )
- self.__column_type_deduplicated_list[index] = TSDataType[
+ self.column_ordinal_dict[name] = index + IoTDBRpcDataSet.START_INDEX
+ self.column_type_deduplicated_list[index] = TSDataType[
column_type_list[i]
]
else:
index = IoTDBRpcDataSet.START_INDEX
- self.__column_type_deduplicated_list = []
+ self.column_type_deduplicated_list = []
for i in range(len(column_name_list)):
name = column_name_list[i]
self.__column_name_list.append(name)
self.__column_type_list.append(TSDataType[column_type_list[i]])
- if name not in self.__column_ordinal_dict:
- self.__column_ordinal_dict[name] = index
+ if name not in self.column_ordinal_dict:
+ self.column_ordinal_dict[name] = index
index += 1
- self.__column_type_deduplicated_list.append(
+ self.column_type_deduplicated_list.append(
TSDataType[column_type_list[i]]
)
-
- self.__time_bytes = bytes(0)
- self.__current_bitmap = [
- bytes(0) for _ in range(len(self.__column_type_deduplicated_list))
- ]
- self.__value = [None for _ in range(len(self.__column_type_deduplicated_list))]
self.__query_data_set = query_data_set
self.__is_closed = False
self.__empty_resultSet = False
- self.__has_cached_record = False
self.__rows_index = 0
+ self.has_cached_data_frame = False
+ self.data_frame = None
def close(self):
if self.__is_closed:
@@ -132,23 +128,120 @@ class IoTDBRpcDataSet(object):
self.__client = None
def next(self):
- if self.has_cached_result():
- self.construct_one_row()
+ if not self.has_cached_data_frame:
+ self.construct_one_data_frame()
+ if self.has_cached_data_frame:
return True
if self.__empty_resultSet:
return False
if self.fetch_results():
- self.construct_one_row()
+ self.construct_one_data_frame()
return True
return False
- def has_cached_result(self):
- return (self.__query_data_set is not None) and (
- len(self.__query_data_set.time) != 0
+ def construct_one_data_frame(self):
+ if (
+ self.has_cached_data_frame
+ or self.__query_data_set is None
+ or len(self.__query_data_set.time) == 0
+ ):
+ return
+ result = {}
+ time_array = np.frombuffer(
+ self.__query_data_set.time, np.dtype(np.longlong).newbyteorder(">")
)
+ if time_array.dtype.byteorder == ">":
+ time_array = time_array.byteswap().newbyteorder("<")
+ result[0] = time_array
+ total_length = len(time_array)
+ for i in range(self.column_size):
+ if self.ignore_timestamp is True:
+ column_name = self.__column_name_list[i]
+ else:
+ column_name = self.__column_name_list[i + 1]
+
+ location = (
+ self.column_ordinal_dict[column_name] - IoTDBRpcDataSet.START_INDEX
+ )
+ if location < 0:
+ continue
+ data_type = self.column_type_deduplicated_list[location]
+ value_buffer = self.__query_data_set.valueList[location]
+ value_buffer_len = len(value_buffer)
+ if data_type == 4:
+ data_array = np.frombuffer(
+ value_buffer, np.dtype(np.double).newbyteorder(">")
+ )
+ elif data_type == 3:
+ data_array = np.frombuffer(
+ value_buffer, np.dtype(np.float32).newbyteorder(">")
+ )
+ elif data_type == 0:
+ data_array = np.frombuffer(value_buffer, np.dtype("?"))
+ elif data_type == 1:
+ data_array = np.frombuffer(
+ value_buffer, np.dtype(np.int32).newbyteorder(">")
+ )
+ elif data_type == 2:
+ data_array = np.frombuffer(
+ value_buffer, np.dtype(np.int64).newbyteorder(">")
+ )
+ elif data_type == 5:
+ j = 0
+ offset = 0
+ data_array = []
+ while offset < value_buffer_len:
+ length = int.from_bytes(
+ value_buffer[offset : offset + 4],
+ byteorder="big",
+ signed=False,
+ )
+ offset += 4
+ value = bytes(value_buffer[offset : offset + length])
+ data_array.append(value)
+ j += 1
+ offset += length
+ data_array = np.array(data_array, dtype=object)
+ else:
+ raise RuntimeError("unsupported data type {}.".format(data_type))
+ if data_array.dtype.byteorder == ">":
+ data_array = data_array.byteswap().newbyteorder("<")
+ # self.__query_data_set.valueList[location] = None
+ if len(data_array) < total_length:
+ # INT32 or INT64 or boolean
+ if data_type == 0 or data_type == 1 or data_type == 2:
+ tmp_array = np.full(total_length, np.nan, np.float32)
+ else:
+ tmp_array = np.full(total_length, None, dtype=object)
+
+ bitmap_buffer = self.__query_data_set.bitmapList[location]
+ buffer = _to_bitbuffer(bitmap_buffer)
+ bit_mask = (np.frombuffer(buffer, "u1") - ord("0")).astype(bool)
+ if len(bit_mask) != total_length:
+ bit_mask = bit_mask[:total_length]
+ tmp_array[bit_mask] = data_array
+
+ if data_type == 1:
+ tmp_array = pd.Series(tmp_array, dtype="Int32")
+ elif data_type == 2:
+ tmp_array = pd.Series(tmp_array, dtype="Int64")
+ elif data_type == 0:
+ tmp_array = pd.Series(tmp_array, dtype="boolean")
+ data_array = tmp_array
+
+ result[i + 1] = data_array
+ self.__query_data_set = None
+ self.data_frame = pd.DataFrame(result, dtype=object)
+ if not self.data_frame.empty:
+ self.has_cached_data_frame = True
+
+ def has_cached_result(self):
+ return self.has_cached_data_frame
def _has_next_result_set(self):
- if self.has_cached_result():
+ if (self.__query_data_set is not None) and (
+ len(self.__query_data_set.time) != 0
+ ):
return True
if self.__empty_resultSet:
return False
@@ -156,11 +249,6 @@ class IoTDBRpcDataSet(object):
return True
return False
- def _to_bitbuffer(self, b):
- return bytes(
- "{:0{}b}".format(int(binascii.hexlify(b), 16), 8 * len(b)), "utf-8"
- )
-
def resultset_to_pandas(self):
result = {}
for column_name in self.__column_name_list:
@@ -171,51 +259,45 @@ class IoTDBRpcDataSet(object):
)
if time_array.dtype.byteorder == ">":
time_array = time_array.byteswap().newbyteorder("<")
- if (
- self.get_ignore_timestamp() is None
- or self.get_ignore_timestamp() is False
- ):
+ if self.ignore_timestamp is None or self.ignore_timestamp is False:
result[IoTDBRpcDataSet.TIMESTAMP_STR].append(time_array)
self.__query_data_set.time = []
total_length = len(time_array)
for i in range(len(self.__query_data_set.bitmapList)):
- if self.get_ignore_timestamp() is True:
- column_name = self.get_column_names()[i]
+ if self.ignore_timestamp is True:
+ column_name = self.__column_name_list[i]
else:
- column_name = self.get_column_names()[i + 1]
+ column_name = self.__column_name_list[i + 1]
location = (
- self.__column_ordinal_dict[column_name]
- - IoTDBRpcDataSet.START_INDEX
+ self.column_ordinal_dict[column_name] -
IoTDBRpcDataSet.START_INDEX
)
if location < 0:
continue
- data_type = self.__column_type_deduplicated_list[location]
+ data_type = self.column_type_deduplicated_list[location]
value_buffer = self.__query_data_set.valueList[location]
value_buffer_len = len(value_buffer)
-
- data_array = None
- if data_type == TSDataType.DOUBLE:
+ if data_type == 4:
data_array = np.frombuffer(
value_buffer, np.dtype(np.double).newbyteorder(">")
)
- elif data_type == TSDataType.FLOAT:
+ elif data_type == 3:
data_array = np.frombuffer(
value_buffer, np.dtype(np.float32).newbyteorder(">")
)
- elif data_type == TSDataType.BOOLEAN:
+ elif data_type == 0:
data_array = np.frombuffer(value_buffer, np.dtype("?"))
- elif data_type == TSDataType.INT32:
+ elif data_type == 1:
data_array = np.frombuffer(
value_buffer, np.dtype(np.int32).newbyteorder(">")
)
- elif data_type == TSDataType.INT64:
+ elif data_type == 2:
data_array = np.frombuffer(
value_buffer, np.dtype(np.int64).newbyteorder(">")
)
- elif data_type == TSDataType.TEXT:
+ elif data_type == 5:
j = 0
offset = 0
data_array = []
@@ -226,7 +308,7 @@ class IoTDBRpcDataSet(object):
signed=False,
)
offset += 4
- value_bytes = value_buffer[offset : offset + length]
+ value_bytes = bytes(value_buffer[offset : offset + length])
value = value_bytes.decode("utf-8")
data_array.append(value)
j += 1
@@ -237,31 +319,29 @@ class IoTDBRpcDataSet(object):
if data_array.dtype.byteorder == ">":
data_array = data_array.byteswap().newbyteorder("<")
self.__query_data_set.valueList[location] = None
-
+ tmp_array = []
if len(data_array) < total_length:
- if data_type == TSDataType.INT32 or data_type ==
TSDataType.INT64:
+ if data_type == 1 or data_type == 2:
tmp_array = np.full(total_length, np.nan, np.float32)
- elif (
- data_type == TSDataType.FLOAT or data_type ==
TSDataType.DOUBLE
- ):
+ elif data_type == 3 or data_type == 4:
tmp_array = np.full(total_length, np.nan,
data_array.dtype)
- elif data_type == TSDataType.BOOLEAN:
+ elif data_type == 0:
tmp_array = np.full(total_length, np.nan, np.float32)
- elif data_type == TSDataType.TEXT:
+ elif data_type == 5:
tmp_array = np.full(total_length, None,
dtype=data_array.dtype)
bitmap_buffer = self.__query_data_set.bitmapList[location]
- buffer = self._to_bitbuffer(bitmap_buffer)
+ buffer = _to_bitbuffer(bitmap_buffer)
bit_mask = (np.frombuffer(buffer, "u1") -
ord("0")).astype(bool)
if len(bit_mask) != total_length:
bit_mask = bit_mask[:total_length]
tmp_array[bit_mask] = data_array
- if data_type == TSDataType.INT32:
+ if data_type == 1:
tmp_array = pd.Series(tmp_array).astype("Int32")
- elif data_type == TSDataType.INT64:
+ elif data_type == 2:
tmp_array = pd.Series(tmp_array).astype("Int64")
- elif data_type == TSDataType.BOOLEAN:
+ elif data_type == 0:
tmp_array = pd.Series(tmp_array).astype("boolean")
data_array = tmp_array
@@ -283,48 +363,6 @@ class IoTDBRpcDataSet(object):
df = pd.DataFrame(result)
return df
- def construct_one_row(self):
- # simulating buffer, read 8 bytes from data set and discard first 8
bytes which have been read.
- self.__time_bytes = self.__query_data_set.time[:8]
- self.__query_data_set.time = self.__query_data_set.time[8:]
- for i in range(len(self.__query_data_set.bitmapList)):
- bitmap_buffer = self.__query_data_set.bitmapList[i]
-
- # another 8 new rows, should move the bitmap buffer position to
next byte
- if self.__rows_index % 8 == 0:
- self.__current_bitmap[i] = bitmap_buffer[0]
- self.__query_data_set.bitmapList[i] = bitmap_buffer[1:]
- if not self.is_null(i, self.__rows_index):
- value_buffer = self.__query_data_set.valueList[i]
- data_type = self.__column_type_deduplicated_list[i]
-
- # simulating buffer
- if data_type == TSDataType.BOOLEAN:
- self.__value[i] = value_buffer[:1]
- self.__query_data_set.valueList[i] = value_buffer[1:]
- elif data_type == TSDataType.INT32:
- self.__value[i] = value_buffer[:4]
- self.__query_data_set.valueList[i] = value_buffer[4:]
- elif data_type == TSDataType.INT64:
- self.__value[i] = value_buffer[:8]
- self.__query_data_set.valueList[i] = value_buffer[8:]
- elif data_type == TSDataType.FLOAT:
- self.__value[i] = value_buffer[:4]
- self.__query_data_set.valueList[i] = value_buffer[4:]
- elif data_type == TSDataType.DOUBLE:
- self.__value[i] = value_buffer[:8]
- self.__query_data_set.valueList[i] = value_buffer[8:]
- elif data_type == TSDataType.TEXT:
- length = int.from_bytes(
- value_buffer[:4], byteorder="big", signed=False
- )
- self.__value[i] = value_buffer[4 : 4 + length]
- self.__query_data_set.valueList[i] = value_buffer[4 +
length :]
- else:
- raise RuntimeError("unsupported data type
{}.".format(data_type))
- self.__rows_index += 1
- self.__has_cached_record = True
-
def fetch_results(self):
self.__rows_index = 0
request = TSFetchResultsReq(
@@ -347,36 +385,12 @@ class IoTDBRpcDataSet(object):
"Cannot fetch result from server, because of network
connection: ", e
)
- def is_null(self, index, row_num):
- bitmap = self.__current_bitmap[index]
- shift = row_num % 8
- return ((IoTDBRpcDataSet.FLAG >> shift) & (bitmap & 0xFF)) == 0
-
- def is_null_by_index(self, column_index):
- index = (
-
self.__column_ordinal_dict[self.find_column_name_by_index(column_index)]
- - IoTDBRpcDataSet.START_INDEX
- )
- # time column will never be None
- if index < 0:
- return True
- return self.is_null(index, self.__rows_index - 1)
-
- def is_null_by_name(self, column_name):
- index = self.__column_ordinal_dict[column_name] -
IoTDBRpcDataSet.START_INDEX
- # time column will never be None
- if index < 0:
- return True
- return self.is_null(index, self.__rows_index - 1)
-
def find_column_name_by_index(self, column_index):
if column_index <= 0:
raise Exception("Column index should start from 1")
if column_index > len(self.__column_name_list):
raise Exception(
- "column index {} out of range {}".format(
- column_index, self.__column_size
- )
"column index {} out of range {}".format(column_index, self.column_size)
)
return self.__column_name_list[column_index - 1]
@@ -391,24 +405,3 @@ class IoTDBRpcDataSet(object):
def get_column_types(self):
return self.__column_type_list
-
- def get_column_size(self):
- return self.__column_size
-
- def get_ignore_timestamp(self):
- return self.__ignore_timestamp
-
- def get_column_ordinal_dict(self):
- return self.__column_ordinal_dict
-
- def get_column_type_deduplicated_list(self):
- return self.__column_type_deduplicated_list
-
- def get_values(self):
- return self.__value
-
- def get_time_bytes(self):
- return self.__time_bytes
-
- def get_has_cached_record(self):
- return self.__has_cached_record
diff --git a/iotdb-client/client-py/iotdb/utils/SessionDataSet.py
b/iotdb-client/client-py/iotdb/utils/SessionDataSet.py
index 02eef027dfd..a7ce7dfc40c 100644
--- a/iotdb-client/client-py/iotdb/utils/SessionDataSet.py
+++ b/iotdb-client/client-py/iotdb/utils/SessionDataSet.py
@@ -16,7 +16,6 @@
# under the License.
#
import logging
-import struct
from iotdb.utils.Field import Field
@@ -55,8 +54,26 @@ class SessionDataSet(object):
statement_id,
session_id,
query_data_set,
- 1024,
+ 5000,
)
+ self.column_size = self.iotdb_rpc_data_set.column_size
+ self.is_ignore_timestamp = self.iotdb_rpc_data_set.ignore_timestamp
+ self.column_names = tuple(self.iotdb_rpc_data_set.get_column_names())
+ self.column_ordinal_dict = self.iotdb_rpc_data_set.column_ordinal_dict
+ self.column_type_deduplicated_list = tuple(
+ self.iotdb_rpc_data_set.column_type_deduplicated_list
+ )
+ if self.is_ignore_timestamp:
+ self.__field_list = [
+ Field(data_type)
+ for data_type in self.iotdb_rpc_data_set.get_column_types()
+ ]
+ else:
+ self.__field_list = [
+ Field(data_type)
+ for data_type in self.iotdb_rpc_data_set.get_column_types()[1:]
+ ]
+ self.row_index = 0
def __enter__(self):
return self
@@ -80,58 +97,29 @@ class SessionDataSet(object):
return self.iotdb_rpc_data_set.next()
def next(self):
- if not self.iotdb_rpc_data_set.get_has_cached_record():
+ if not self.iotdb_rpc_data_set.has_cached_data_frame:
if not self.has_next():
return None
- self.iotdb_rpc_data_set.has_cached_record = False
- return self.construct_row_record_from_value_array()
-
- def construct_row_record_from_value_array(self):
- out_fields = []
- for i in range(self.iotdb_rpc_data_set.get_column_size()):
- index = i + 1
- data_set_column_index = i + IoTDBRpcDataSet.START_INDEX
- if self.iotdb_rpc_data_set.get_ignore_timestamp():
- index -= 1
- data_set_column_index -= 1
- column_name = self.iotdb_rpc_data_set.get_column_names()[index]
- location = (
- self.iotdb_rpc_data_set.get_column_ordinal_dict()[column_name]
- - IoTDBRpcDataSet.START_INDEX
- )
-
- if not
self.iotdb_rpc_data_set.is_null_by_index(data_set_column_index):
- value_bytes = self.iotdb_rpc_data_set.get_values()[location]
- data_type =
self.iotdb_rpc_data_set.get_column_type_deduplicated_list()[
- location
- ]
- field = Field(data_type)
- if data_type == TSDataType.BOOLEAN:
- value = struct.unpack(">?", value_bytes)[0]
- field.set_bool_value(value)
- elif data_type == TSDataType.INT32:
- value = struct.unpack(">i", value_bytes)[0]
- field.set_int_value(value)
- elif data_type == TSDataType.INT64:
- value = struct.unpack(">q", value_bytes)[0]
- field.set_long_value(value)
- elif data_type == TSDataType.FLOAT:
- value = struct.unpack(">f", value_bytes)[0]
- field.set_float_value(value)
- elif data_type == TSDataType.DOUBLE:
- value = struct.unpack(">d", value_bytes)[0]
- field.set_double_value(value)
- elif data_type == TSDataType.TEXT:
- field.set_binary_value(value_bytes)
- else:
- raise RuntimeError("unsupported data type
{}.".format(data_type))
- else:
- field = Field(None)
- out_fields.append(field)
-
- return RowRecord(
- struct.unpack(">q", self.iotdb_rpc_data_set.get_time_bytes())[0],
out_fields
+ return self.construct_row_record_from_data_frame()
+
+ def construct_row_record_from_data_frame(self):
+ df = self.iotdb_rpc_data_set.data_frame
+ row = df.iloc[self.row_index].to_list()
+ row_values = row[1:]
+ for i, value in enumerate(row_values):
+ self.__field_list[i].value = value
+
+ row_record = RowRecord(
+ row[0],
+ self.__field_list,
)
+ self.row_index += 1
+ if self.row_index == len(df):
+ self.row_index = 0
+ self.iotdb_rpc_data_set.has_cached_data_frame = False
+ self.iotdb_rpc_data_set.data_frame = None
+
+ return row_record
def close_operation_handle(self):
self.iotdb_rpc_data_set.close()
diff --git a/iotdb-client/client-py/tests/test_tablet.py
b/iotdb-client/client-py/tests/test_tablet.py
index 3e2d53543e4..8a035318d9f 100644
--- a/iotdb-client/client-py/tests/test_tablet.py
+++ b/iotdb-client/client-py/tests/test_tablet.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
#
-
+import numpy as np
import pandas as pd
from pandas.testing import assert_frame_equal
@@ -55,14 +55,16 @@ def test_tablet_insertion():
columns = []
for measurement in measurements_:
columns.append("root.sg_test_01.d_01." + measurement)
- df_input = pd.DataFrame(values_, None, columns)
- df_input.insert(0, "Time", timestamps_)
+ df_input = pd.DataFrame(values_, columns=columns, dtype=object)
+ df_input.insert(0, "Time", np.array(timestamps_))
session_data_set = session.execute_query_statement(
"select s_01, s_02, s_03, s_04, s_05, s_06 from
root.sg_test_01.d_01"
)
df_output = session_data_set.todf()
- df_output = df_output[df_input.columns.tolist()]
+ df_output = df_output[df_input.columns.tolist()].replace(
+ {pd.NA: None, np.nan: None}
+ )
session.close()
assert_frame_equal(df_input, df_output, False)
@@ -98,14 +100,16 @@ def test_nullable_tablet_insertion():
columns = []
for measurement in measurements_:
columns.append("root.sg_test_01.d_01." + measurement)
- df_input = pd.DataFrame(values_, None, columns)
- df_input.insert(0, "Time", timestamps_)
+ df_input = pd.DataFrame(values_, columns=columns, dtype=object)
+ df_input.insert(0, "Time", np.array(timestamps_))
session_data_set = session.execute_query_statement(
"select s_01, s_02, s_03, s_04, s_05, s_06 from
root.sg_test_01.d_01"
)
df_output = session_data_set.todf()
df_output = df_output[df_input.columns.tolist()]
-
+ df_output = df_output[df_input.columns.tolist()].replace(
+ {pd.NA: None, np.nan: None}
+ )
session.close()
assert_frame_equal(df_input, df_output, False)