This is an automated email from the ASF dual-hosted git repository.

rong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 1eb9e0e  Refactor the todf() function of client-py to improve performance (#4242)
1eb9e0e is described below

commit 1eb9e0e24e5829d8f54c5f8573d3ae41bab3f861
Author: Wei Fu <[email protected]>
AuthorDate: Fri Nov 19 18:55:41 2021 +0800

    Refactor the todf() function of client-py to improve performance (#4242)
    
    * Refactor the todf() function of client-py to improve performance.
    
    * When the data type is INT32, INT64, or BOOLEAN, represent null values
    with pd.NA (using the nullable Int32/Int64/boolean pandas dtypes).
    
    * Add test cases for the todf() function.
    
    * [bugfix] Fix the exception raised when constructing a DataFrame from a
    result set whose returned data is entirely empty.
    
    * black . && flake8 .
    
    Co-authored-by: wei.fu <[email protected]>
    Co-authored-by: Steve Yurong Su <[email protected]>
---
 client-py/SessionExample.py                      |   4 +-
 client-py/iotdb/utils/BitMap.py                  |   3 +-
 client-py/iotdb/utils/IoTDBConstants.py          |   1 +
 client-py/iotdb/utils/IoTDBRpcDataSet.py         | 138 ++++++++++++++-
 client-py/iotdb/utils/SessionDataSet.py          |  26 +--
 client-py/iotdb/utils/Tablet.py                  |  10 +-
 client-py/tests/tablet_performance_comparison.py | 129 ++++++++++----
 client-py/tests/test_dataframe.py                |  29 ++-
 client-py/tests/test_todf.py                     | 216 +++++++++++++++++++++++
 9 files changed, 482 insertions(+), 74 deletions(-)

diff --git a/client-py/SessionExample.py b/client-py/SessionExample.py
index 882bcfe..722c2b7 100644
--- a/client-py/SessionExample.py
+++ b/client-py/SessionExample.py
@@ -178,7 +178,9 @@ session.execute_non_query_statement(
 )
 
 # execute sql query statement
-with session.execute_query_statement("select * from root.sg_test_01.d_01") as 
session_data_set:
+with session.execute_query_statement(
+    "select * from root.sg_test_01.d_01"
+) as session_data_set:
     session_data_set.set_fetch_size(1024)
     while session_data_set.has_next():
         print(session_data_set.next())
diff --git a/client-py/iotdb/utils/BitMap.py b/client-py/iotdb/utils/BitMap.py
index 217349f..7b171f6 100644
--- a/client-py/iotdb/utils/BitMap.py
+++ b/client-py/iotdb/utils/BitMap.py
@@ -16,13 +16,14 @@
 # under the License.
 #
 
+
 class BitMap(object):
     BIT_UTIL = [1, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7]
 
     def __init__(self, size):
         self.__size = size
         self.bits = []
-        for i in range (size // 8 + 1):
+        for i in range(size // 8 + 1):
             self.bits.append(0)
 
     def mark(self, position):
diff --git a/client-py/iotdb/utils/IoTDBConstants.py 
b/client-py/iotdb/utils/IoTDBConstants.py
index ff92773..5922dc3 100644
--- a/client-py/iotdb/utils/IoTDBConstants.py
+++ b/client-py/iotdb/utils/IoTDBConstants.py
@@ -41,6 +41,7 @@ class TSEncoding(Enum):
     REGULAR = 7
     GORILLA = 8
 
+
 @unique
 class Compressor(Enum):
     UNCOMPRESSED = 0
diff --git a/client-py/iotdb/utils/IoTDBRpcDataSet.py 
b/client-py/iotdb/utils/IoTDBRpcDataSet.py
index 6520a04..83468ad 100644
--- a/client-py/iotdb/utils/IoTDBRpcDataSet.py
+++ b/client-py/iotdb/utils/IoTDBRpcDataSet.py
@@ -17,8 +17,11 @@
 #
 
 # for package
+import binascii
 import logging
 
+import numpy as np
+import pandas as pd
 from thrift.transport import TTransport
 from iotdb.thrift.rpc.TSIService import TSFetchResultsReq, TSCloseOperationReq
 from iotdb.utils.IoTDBConstants import TSDataType
@@ -111,7 +114,9 @@ class IoTDBRpcDataSet(object):
         if self.__client is not None:
             try:
                 status = self.__client.closeOperation(
-                    TSCloseOperationReq(self.__session_id, self.__query_id, 
self.__statement_id)
+                    TSCloseOperationReq(
+                        self.__session_id, self.__query_id, self.__statement_id
+                    )
                 )
                 logger.debug(
                     "close session {}, message: {}".format(
@@ -142,6 +147,137 @@ class IoTDBRpcDataSet(object):
             len(self.__query_data_set.time) != 0
         )
 
+    def _has_next_result_set(self):
+        if self.has_cached_result():
+            return True
+        if self.__empty_resultSet:
+            return False
+        if self.fetch_results():
+            return True
+        return False
+
+    def _to_bitstring(self, b):
+        return "{:0{}b}".format(int(binascii.hexlify(b), 16), 8 * len(b))
+
+    def resultset_to_pandas(self):
+        result = {}
+        for column_name in self.__column_name_list:
+            result[column_name] = None
+        while self._has_next_result_set():
+            time_array = np.frombuffer(
+                self.__query_data_set.time, 
np.dtype(np.longlong).newbyteorder(">")
+            )
+            if time_array.dtype.byteorder == ">":
+                time_array = time_array.byteswap().newbyteorder("<")
+            if (
+                self.get_ignore_timestamp() is None
+                or self.get_ignore_timestamp() is False
+            ):
+                if result[IoTDBRpcDataSet.TIMESTAMP_STR] is None:
+                    result[IoTDBRpcDataSet.TIMESTAMP_STR] = time_array
+                else:
+                    result[IoTDBRpcDataSet.TIMESTAMP_STR] = np.concatenate(
+                        (result[IoTDBRpcDataSet.TIMESTAMP_STR], time_array), 
axis=0
+                    )
+            self.__query_data_set.time = []
+            total_length = len(time_array)
+
+            for i in range(len(self.__query_data_set.bitmapList)):
+                if self.get_ignore_timestamp() is True:
+                    column_name = self.get_column_names()[i]
+                else:
+                    column_name = self.get_column_names()[i + 1]
+
+                location = (
+                    self.__column_ordinal_dict[column_name]
+                    - IoTDBRpcDataSet.START_INDEX
+                )
+                if location < 0:
+                    continue
+                data_type = self.__column_type_deduplicated_list[location]
+                value_buffer = self.__query_data_set.valueList[location]
+                value_buffer_len = len(value_buffer)
+
+                data_array = None
+                if data_type == TSDataType.DOUBLE:
+                    data_array = np.frombuffer(
+                        value_buffer, np.dtype(np.double).newbyteorder(">")
+                    )
+                elif data_type == TSDataType.FLOAT:
+                    data_array = np.frombuffer(
+                        value_buffer, np.dtype(np.float32).newbyteorder(">")
+                    )
+                elif data_type == TSDataType.BOOLEAN:
+                    data_array = np.frombuffer(value_buffer, np.dtype("?"))
+                elif data_type == TSDataType.INT32:
+                    data_array = np.frombuffer(
+                        value_buffer, np.dtype(np.int32).newbyteorder(">")
+                    )
+                elif data_type == TSDataType.INT64:
+                    data_array = np.frombuffer(
+                        value_buffer, np.dtype(np.int64).newbyteorder(">")
+                    )
+                elif data_type == TSDataType.TEXT:
+                    j = 0
+                    offset = 0
+                    data_array = []
+                    while offset < value_buffer_len:
+                        length = int.from_bytes(
+                            value_buffer[offset : offset + 4],
+                            byteorder="big",
+                            signed=False,
+                        )
+                        offset += 4
+                        value_bytes = value_buffer[offset : offset + length]
+                        value = value_bytes.decode("utf-8")
+                        data_array.append(value)
+                        j += 1
+                        offset += length
+                    data_array = np.array(data_array, dtype=np.object)
+                else:
+                    raise RuntimeError("unsupported data type 
{}.".format(data_type))
+                if data_array.dtype.byteorder == ">":
+                    data_array = data_array.byteswap().newbyteorder("<")
+                self.__query_data_set.valueList[location] = None
+
+                if len(data_array) < total_length:
+                    if data_type == TSDataType.INT32 or data_type == 
TSDataType.INT64:
+                        tmp_array = np.full(total_length, np.nan, np.float32)
+                        if data_array.dtype == np.int32:
+                            tmp_array = pd.Series(tmp_array).astype("Int32")
+                        else:
+                            tmp_array = pd.Series(tmp_array).astype("Int64")
+                    elif (
+                        data_type == TSDataType.FLOAT or data_type == 
TSDataType.DOUBLE
+                    ):
+                        tmp_array = np.full(total_length, np.nan, 
data_array.dtype)
+                    elif data_type == TSDataType.BOOLEAN:
+                        tmp_array = np.full(total_length, np.nan, np.float32)
+                        tmp_array = pd.Series(tmp_array).astype("boolean")
+                    elif data_type == TSDataType.TEXT:
+                        tmp_array = np.full(total_length, None, 
dtype=data_array.dtype)
+                    bitmap_buffer = self.__query_data_set.bitmapList[location]
+                    bitmap_str = self._to_bitstring(bitmap_buffer)
+                    j = 0
+                    for index in range(total_length):
+                        if bitmap_str[index] == "1":
+                            tmp_array[index] = data_array[j]
+                            j += 1
+                    data_array = tmp_array
+
+                if result[column_name] is None:
+                    result[column_name] = data_array
+                else:
+                    result[column_name] = np.concatenate(
+                        (result[column_name], data_array), axis=0
+                    )
+        for k, v in result.items():
+            if v is None:
+                result[k] = []
+
+        df = pd.DataFrame(result)
+        return df
+
     def construct_one_row(self):
         # simulating buffer, read 8 bytes from data set and discard first 8 
bytes which have been read.
         self.__time_bytes = self.__query_data_set.time[:8]
diff --git a/client-py/iotdb/utils/SessionDataSet.py 
b/client-py/iotdb/utils/SessionDataSet.py
index 18c2c3c..02eef02 100644
--- a/client-py/iotdb/utils/SessionDataSet.py
+++ b/client-py/iotdb/utils/SessionDataSet.py
@@ -147,31 +147,7 @@ def resultset_to_pandas(result_set: SessionDataSet) -> 
pd.DataFrame:
     :param result_set:
     :return:
     """
-    # get column names and fields
-    column_names = result_set.get_column_names()
-
-    value_dict = {}
-
-    if "Time" in column_names:
-        offset = 1
-    else:
-        offset = 0
-
-    for i in range(len(column_names)):
-        value_dict[column_names[i]] = []
-
-    while result_set.has_next():
-        record = result_set.next()
-
-        if "Time" in column_names:
-            value_dict["Time"].append(record.get_timestamp())
-
-        for col in range(len(record.get_fields())):
-            field: Field = record.get_fields()[col]
-
-            value_dict[column_names[col + 
offset]].append(get_typed_point(field))
-
-    return pd.DataFrame(value_dict)
+    return result_set.iotdb_rpc_data_set.resultset_to_pandas()
 
 
 def get_typed_point(field: Field, none_value=None):
diff --git a/client-py/iotdb/utils/Tablet.py b/client-py/iotdb/utils/Tablet.py
index a71ad74..dc75546 100644
--- a/client-py/iotdb/utils/Tablet.py
+++ b/client-py/iotdb/utils/Tablet.py
@@ -23,7 +23,9 @@ from iotdb.utils.BitMap import BitMap
 
 
 class Tablet(object):
-    def __init__(self, device_id, measurements, data_types, values, 
timestamps, use_new=False):
+    def __init__(
+        self, device_id, measurements, data_types, values, timestamps, 
use_new=False
+    ):
         """
         creating a tablet for insertion
           for example, considering device: root.sg1.d1
@@ -176,7 +178,9 @@ class Tablet(object):
                             has_none = True
 
                 else:
-                    raise RuntimeError("Unsupported data type:" + 
str(self.__data_types[i]))
+                    raise RuntimeError(
+                        "Unsupported data type:" + str(self.__data_types[i])
+                    )
 
             if has_none:
                 for i in range(self.__column_number):
@@ -216,7 +220,7 @@ class Tablet(object):
             offset = 0
             for bs in bs_list:
                 _l = len(bs)
-                ret[offset:offset + _l] = bs
+                ret[offset : offset + _l] = bs
                 offset += _l
             return ret
 
diff --git a/client-py/tests/tablet_performance_comparison.py 
b/client-py/tests/tablet_performance_comparison.py
index 2bc8fe2..4cbbe53 100644
--- a/client-py/tests/tablet_performance_comparison.py
+++ b/client-py/tests/tablet_performance_comparison.py
@@ -27,15 +27,17 @@ from iotdb.utils.IoTDBConstants import TSDataType
 from iotdb.utils.Tablet import Tablet
 
 # the data type specified the byte order (i.e. endian)
-FORMAT_CHAR_OF_TYPES = {TSDataType.BOOLEAN: ">?",
-                        TSDataType.FLOAT: ">f4",
-                        TSDataType.DOUBLE: ">f8",
-                        TSDataType.INT32: ">i4",
-                        TSDataType.INT64: ">i8",
-                        TSDataType.TEXT: "str"}
+FORMAT_CHAR_OF_TYPES = {
+    TSDataType.BOOLEAN: ">?",
+    TSDataType.FLOAT: ">f4",
+    TSDataType.DOUBLE: ">f8",
+    TSDataType.INT32: ">i4",
+    TSDataType.INT64: ">i8",
+    TSDataType.TEXT: "str",
+}
 
 # the time column name in the csv file.
-TIME_STR = 'time'
+TIME_STR = "time"
 
 
 def load_csv_data(measure_tstype_infos: dict, data_file_name: str) -> 
pd.DataFrame:
@@ -52,7 +54,9 @@ def load_csv_data(measure_tstype_infos: dict, data_file_name: 
str) -> pd.DataFra
     return df
 
 
-def generate_csv_data(measure_tstype_infos: dict, data_file_name: str, _row: 
int, seed=0) -> None:
+def generate_csv_data(
+    measure_tstype_infos: dict, data_file_name: str, _row: int, seed=0
+) -> None:
     """
     generate csv data randomly according to given measurements and their data 
types.
     :param measure_tstype_infos: key(str): measurement name, 
value(TSDataType): measurement data type
@@ -61,29 +65,38 @@ def generate_csv_data(measure_tstype_infos: dict, 
data_file_name: str, _row: int
     :param seed: random seed
     """
     import random
+
     random.seed(seed)
 
-    CHAR_BASE = 
'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
+    CHAR_BASE = 
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
 
     def generate_data(_type: TSDataType):
         if _type == TSDataType.BOOLEAN:
             return [random.randint(0, 1) == 1 for _ in range(_row)]
         elif _type == TSDataType.INT32:
-            return [random.randint(-2 ** 31, 2 ** 31) for _ in range(_row)]
+            return [random.randint(-(2 ** 31), 2 ** 31) for _ in range(_row)]
         elif _type == TSDataType.INT64:
-            return [random.randint(-2 ** 63, 2 ** 63) for _ in range(_row)]
+            return [random.randint(-(2 ** 63), 2 ** 63) for _ in range(_row)]
         elif _type == TSDataType.FLOAT:
             return [1.5 for _ in range(_row)]
         elif _type == TSDataType.DOUBLE:
             return [0.844421 for _ in range(_row)]
         elif _type == TSDataType.TEXT:
-            return [''.join(random.choice(CHAR_BASE) for _ in range(5)) for _ 
in range(_row)]
+            return [
+                "".join(random.choice(CHAR_BASE) for _ in range(5)) for _ in 
range(_row)
+            ]
         else:
-            raise TypeError('not support type:' + str(_type))
+            raise TypeError("not support type:" + str(_type))
 
-    values = {TIME_STR: pd.Series(np.arange(_row), 
dtype=FORMAT_CHAR_OF_TYPES[TSDataType.INT64])}
+    values = {
+        TIME_STR: pd.Series(
+            np.arange(_row), dtype=FORMAT_CHAR_OF_TYPES[TSDataType.INT64]
+        )
+    }
     for column, data_type in measure_tstype_infos.items():
-        values[column] = pd.Series(generate_data(data_type), 
dtype=FORMAT_CHAR_OF_TYPES[data_type])
+        values[column] = pd.Series(
+            generate_data(data_type), dtype=FORMAT_CHAR_OF_TYPES[data_type]
+        )
 
     df = pd.DataFrame(values)
     df.to_csv(data_file_name, index=False)
@@ -119,7 +132,9 @@ def check_count(expect, _session, _sql):
             assert False, "select count return more than one line"
         line = session_data_set.next()
         actual = line.get_fields()[0].get_long_value()
-        assert expect == actual, f"count error: expect {expect} lines, actual 
{actual} lines"
+        assert (
+            expect == actual
+        ), f"count error: expect {expect} lines, actual {actual} lines"
         get_count_line = True
     if not get_count_line:
         assert False, "select count has no result"
@@ -138,13 +153,22 @@ def check_query_result(expect, _session, _sql):
     idx = 0
     while session_data_set.has_next():
         line = session_data_set.next()
-        assert str(line) == expect[idx], f"line {idx}: actual {str(line)} != 
expect ({expect[idx]})"
+        assert (
+            str(line) == expect[idx]
+        ), f"line {idx}: actual {str(line)} != expect ({expect[idx]})"
         idx += 1
     assert idx == len(expect), f"result rows: actual ({idx}) != expect 
({len(expect)})"
     session_data_set.close_operation_handle()
 
 
-def performance_test(measure_tstype_infos, data_file_name, use_new=True, 
check_result=False, row=10000, col=5000):
+def performance_test(
+    measure_tstype_infos,
+    data_file_name,
+    use_new=True,
+    check_result=False,
+    row=10000,
+    col=5000,
+):
     """
     execute tablet insert using original or new methods.
     :param measure_tstype_infos: key(str): measurement name, 
value(TSDataType): measurement data type
@@ -153,12 +177,14 @@ def performance_test(measure_tstype_infos, 
data_file_name, use_new=True, check_r
     :param row: tablet row number
     :param col: tablet column number
     """
-    print(f"Test python: use new: {use_new}, row: {row}, col: {col}. 
measurements: {measure_tstype_infos}")
+    print(
+        f"Test python: use new: {use_new}, row: {row}, col: {col}. 
measurements: {measure_tstype_infos}"
+    )
     print(f"Total points: {len(measure_tstype_infos) * row * col}")
 
     # open the session and clean data
     session = create_open_session()
-    session.execute_non_query_statement(f'delete timeseries root.*')
+    session.execute_non_query_statement("delete timeseries root.*")
 
     # test start
     st = time.perf_counter()
@@ -195,7 +221,9 @@ def performance_test(measure_tstype_infos, data_file_name, 
use_new=True, check_r
                         value_array = value_array.astype(type_char)
                 values.append(value_array)
 
-        tablet = Tablet(device_id, measurements, data_types, values, 
timestamps_, use_new=use_new)
+        tablet = Tablet(
+            device_id, measurements, data_types, values, timestamps_, 
use_new=use_new
+        )
         cost_st = time.perf_counter()
         session.insert_tablet(tablet)
         insert_cost += time.perf_counter() - cost_st
@@ -208,12 +236,14 @@ def performance_test(measure_tstype_infos, 
data_file_name, use_new=True, check_r
                 for m in measurements:
                     line.append(str(csv_data.at[t, m]))
                 expect.append("\t\t".join([v for v in line]))
-            check_query_result(expect, session, f"select 
{','.join(measurements)} from {device_id}")
+            check_query_result(
+                expect, session, f"select {','.join(measurements)} from 
{device_id}"
+            )
             print("query validation have passed")
     end = time.perf_counter()
 
     # clean data and close the session
-    session.execute_non_query_statement(f'delete timeseries root.*')
+    session.execute_non_query_statement("delete timeseries root.*")
     session.close()
 
     print("load cost: %.3f s" % load_cost)
@@ -222,27 +252,48 @@ def performance_test(measure_tstype_infos, 
data_file_name, use_new=True, check_r
     print("total cost: %.3f s" % (end - st))
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='tablet performance 
comparison')
-    parser.add_argument('--row', type=int, default=10000, help="the row number 
of the input tablet")
-    parser.add_argument('--col', type=int, default=5000, help="the column 
number of the input tablet")
-    parser.add_argument('--check_result', '-c', action="store_true", 
help="True if check out the result")
-    parser.add_argument('--use_new', '-n', action="store_false", help="True if 
use the new tablet insert")
-    parser.add_argument('--seed', type=int, default=0, help="the random seed 
for generating csv data")
-    parser.add_argument('--data_file_name', type=str, default='sample.csv', 
help="the path of csv data")
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="tablet performance 
comparison")
+    parser.add_argument(
+        "--row", type=int, default=10000, help="the row number of the input 
tablet"
+    )
+    parser.add_argument(
+        "--col", type=int, default=5000, help="the column number of the input 
tablet"
+    )
+    parser.add_argument(
+        "--check_result", "-c", action="store_true", help="True if check out 
the result"
+    )
+    parser.add_argument(
+        "--use_new",
+        "-n",
+        action="store_false",
+        help="True if use the new tablet insert",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=0, help="the random seed for generating 
csv data"
+    )
+    parser.add_argument(
+        "--data_file_name", type=str, default="sample.csv", help="the path of 
csv data"
+    )
     args = parser.parse_args()
 
     measure_tstype_infos = {
-        's0': TSDataType.BOOLEAN,
-        's1': TSDataType.FLOAT,
-        's2': TSDataType.INT32,
-        's3': TSDataType.DOUBLE,
-        's4': TSDataType.INT64,
-        's5': TSDataType.TEXT,
+        "s0": TSDataType.BOOLEAN,
+        "s1": TSDataType.FLOAT,
+        "s2": TSDataType.INT32,
+        "s3": TSDataType.DOUBLE,
+        "s4": TSDataType.INT64,
+        "s5": TSDataType.TEXT,
     }
     # if not os.path.exists(args.data_file_name):
     random.seed(a=args.seed, version=2)
     generate_csv_data(measure_tstype_infos, args.data_file_name, args.row, 
args.seed)
 
-    performance_test(measure_tstype_infos, data_file_name=args.data_file_name, 
use_new=args.use_new,
-                     check_result=args.check_result, row=args.row, 
col=args.col)
+    performance_test(
+        measure_tstype_infos,
+        data_file_name=args.data_file_name,
+        use_new=args.use_new,
+        check_result=args.check_result,
+        row=args.row,
+        col=args.col,
+    )
diff --git a/client-py/tests/test_dataframe.py 
b/client-py/tests/test_dataframe.py
index cd47b30..4cfa576 100644
--- a/client-py/tests/test_dataframe.py
+++ b/client-py/tests/test_dataframe.py
@@ -56,7 +56,28 @@ def test_non_time_query():
 
         session.close()
 
-    assert list(df.columns) == ['timeseries', 'alias', 'storage group', 
'dataType', 'encoding', 'compression', 'tags',
-                                'attributes']
-    assert_array_equal(df.values,
-                       [['root.device.pressure', None, 'root.device', 'FLOAT', 
'GORILLA', 'SNAPPY', None, None]])
+    assert list(df.columns) == [
+        "timeseries",
+        "alias",
+        "storage group",
+        "dataType",
+        "encoding",
+        "compression",
+        "tags",
+        "attributes",
+    ]
+    assert_array_equal(
+        df.values,
+        [
+            [
+                "root.device.pressure",
+                None,
+                "root.device",
+                "FLOAT",
+                "GORILLA",
+                "SNAPPY",
+                None,
+                None,
+            ]
+        ],
+    )
diff --git a/client-py/tests/test_todf.py b/client-py/tests/test_todf.py
new file mode 100644
index 0000000..73f9fa4
--- /dev/null
+++ b/client-py/tests/test_todf.py
@@ -0,0 +1,216 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import random
+
+import numpy as np
+import pandas as pd
+from pandas.testing import assert_frame_equal
+
+from iotdb.IoTDBContainer import IoTDBContainer
+from iotdb.Session import Session
+from iotdb.utils.IoTDBConstants import TSDataType, TSEncoding, Compressor
+from iotdb.utils.Tablet import Tablet
+
+device_id = "root.wt1"
+
+ts_path_lst = [
+    "root.wt1.temperature",
+    "root.wt1.windspeed",
+    "root.wt1.angle",
+    "root.wt1.altitude",
+    "root.wt1.status",
+    "root.wt1.hardware",
+]
+measurements = [
+    "temperature",
+    "windspeed",
+    "angle",
+    "altitude",
+    "status",
+    "hardware",
+]
+data_type_lst = [
+    TSDataType.FLOAT,
+    TSDataType.DOUBLE,
+    TSDataType.INT32,
+    TSDataType.INT64,
+    TSDataType.BOOLEAN,
+    TSDataType.TEXT,
+]
+
+
+def create_ts(session):
+    # setting time series.
+    encoding_lst = [TSEncoding.PLAIN for _ in range(len(data_type_lst))]
+    compressor_lst = [Compressor.SNAPPY for _ in range(len(data_type_lst))]
+    session.create_multi_time_series(
+        ts_path_lst, data_type_lst, encoding_lst, compressor_lst
+    )
+
+
+def test_simple_query():
+    with IoTDBContainer() as db:
+        db: IoTDBContainer
+        session = Session(db.get_container_host_ip(), 
db.get_exposed_port(6667))
+        session.open(False)
+
+        create_ts(session)
+
+        # insert data
+        data_nums = 100
+        data = {}
+        timestamps = np.arange(data_nums)
+        data[ts_path_lst[0]] = np.float32(np.random.rand(data_nums))
+        data[ts_path_lst[1]] = np.random.rand(data_nums)
+        data[ts_path_lst[2]] = np.random.randint(10, 100, data_nums, 
dtype="int32")
+        data[ts_path_lst[3]] = np.random.randint(10, 100, data_nums, 
dtype="int64")
+        data[ts_path_lst[4]] = np.random.choice([True, False], size=data_nums)
+        data[ts_path_lst[5]] = np.random.choice(["text1", "text2"], 
size=data_nums)
+
+        df_input = pd.DataFrame(data)
+
+        tablet = Tablet(
+            device_id, measurements, data_type_lst, df_input.values, timestamps
+        )
+        session.insert_tablet(tablet)
+
+        df_input.insert(0, "Time", timestamps)
+
+        session_data_set = session.execute_query_statement("SELECT * FROM 
root.*")
+        df_output = session_data_set.todf()
+        df_output = df_output[df_input.columns.tolist()]
+
+        session.close()
+    assert_frame_equal(df_input, df_output)
+
+
+def test_with_null_query():
+    with IoTDBContainer() as db:
+        db: IoTDBContainer
+        session = Session(db.get_container_host_ip(), 
db.get_exposed_port(6667))
+        session.open(False)
+
+        create_ts(session)
+
+        # insert data
+        data_nums = 100
+        data = {}
+        timestamps = np.arange(data_nums)
+        data[ts_path_lst[0]] = np.float32(np.random.rand(data_nums))
+        data[ts_path_lst[1]] = np.random.rand(data_nums)
+        data[ts_path_lst[2]] = np.random.randint(10, 100, data_nums, 
dtype="int32")
+        data[ts_path_lst[3]] = np.random.randint(10, 100, data_nums, 
dtype="int64")
+        data[ts_path_lst[4]] = np.random.choice([True, False], 
size=data_nums).astype(
+            "bool"
+        )
+        data[ts_path_lst[5]] = np.random.choice(
+            ["text1", "text2"], size=data_nums
+        ).astype(np.object)
+
+        data_empty = {}
+        for ts_path in ts_path_lst:
+            if data[ts_path].dtype == np.int32 or data[ts_path].dtype == 
np.int64:
+                tmp_array = np.full(data_nums, np.nan, np.float32)
+                if data[ts_path].dtype == np.int32:
+                    tmp_array = pd.Series(tmp_array).astype("Int32")
+                else:
+                    tmp_array = pd.Series(tmp_array).astype("Int64")
+            elif data[ts_path].dtype == np.float32 or data[ts_path].dtype == 
np.double:
+                tmp_array = np.full(data_nums, np.nan, data[ts_path].dtype)
+            elif data[ts_path].dtype == np.bool:
+                tmp_array = np.full(data_nums, np.nan, np.float32)
+                tmp_array = pd.Series(tmp_array).astype("boolean")
+            else:
+                tmp_array = np.full(data_nums, None, dtype=data[ts_path].dtype)
+            data_empty[ts_path] = tmp_array
+        df_input = pd.DataFrame(data_empty)
+
+        for row_index in range(data_nums):
+            is_row_inserted = False
+            for column_index in range(len(measurements)):
+                if random.choice([True, False]):
+                    session.insert_record(
+                        device_id,
+                        row_index,
+                        [measurements[column_index]],
+                        [data_type_lst[column_index]],
+                        [data[ts_path_lst[column_index]].tolist()[row_index]],
+                    )
+                    df_input.at[row_index, ts_path_lst[column_index]] = data[
+                        ts_path_lst[column_index]
+                    ][row_index]
+                    is_row_inserted = True
+            if not is_row_inserted:
+                column_index = 0
+                session.insert_record(
+                    device_id,
+                    row_index,
+                    [measurements[column_index]],
+                    [data_type_lst[column_index]],
+                    [data[ts_path_lst[column_index]].tolist()[row_index]],
+                )
+                df_input.at[row_index, ts_path_lst[column_index]] = data[
+                    ts_path_lst[column_index]
+                ][row_index]
+
+        df_input.insert(0, "Time", timestamps)
+
+        session_data_set = session.execute_query_statement("SELECT * FROM 
root.*")
+        df_output = session_data_set.todf()
+        df_output = df_output[df_input.columns.tolist()]
+
+        session.close()
+    assert_frame_equal(df_input, df_output)
+
+
+def test_multi_fetch():
+    with IoTDBContainer() as db:
+        db: IoTDBContainer
+        session = Session(db.get_container_host_ip(), 
db.get_exposed_port(6667))
+        session.open(False)
+
+        create_ts(session)
+
+        # insert data
+        data_nums = 990
+        data = {}
+        timestamps = np.arange(data_nums)
+        data[ts_path_lst[0]] = np.float32(np.random.rand(data_nums))
+        data[ts_path_lst[1]] = np.random.rand(data_nums)
+        data[ts_path_lst[2]] = np.random.randint(10, 100, data_nums, 
dtype="int32")
+        data[ts_path_lst[3]] = np.random.randint(10, 100, data_nums, 
dtype="int64")
+        data[ts_path_lst[4]] = np.random.choice([True, False], size=data_nums)
+        data[ts_path_lst[5]] = np.random.choice(["text1", "text2"], 
size=data_nums)
+
+        df_input = pd.DataFrame(data)
+
+        tablet = Tablet(
+            device_id, measurements, data_type_lst, df_input.values, timestamps
+        )
+        session.insert_tablet(tablet)
+
+        df_input.insert(0, "Time", timestamps)
+
+        session_data_set = session.execute_query_statement("SELECT * FROM 
root.*")
+        session_data_set.set_fetch_size(100)
+        df_output = session_data_set.todf()
+        df_output = df_output[df_input.columns.tolist()]
+
+        session.close()
+    assert_frame_equal(df_input, df_output)

Reply via email to