Hi Hackers,

Please find attached a patch that adds a new Python test checking the
transaction status after executing queries and saving data, with auto-commit
both on and off. It also renames execute_query_utils.py to
execute_query_test_utils.py and refactors some previously written Python tests.

Please review!
Thanks, and regards.

-- 
*Yosry Muhammad Yosry*

Computer Engineering student,
The Faculty of Engineering,
Cairo University (2021).
Class representative of CMP 2021.
https://www.linkedin.com/in/yosrym93/
diff --git a/web/pgadmin/tools/sqleditor/tests/execute_query_utils.py b/web/pgadmin/tools/sqleditor/tests/execute_query_test_utils.py
similarity index 100%
rename from web/pgadmin/tools/sqleditor/tests/execute_query_utils.py
rename to web/pgadmin/tools/sqleditor/tests/execute_query_test_utils.py
diff --git a/web/pgadmin/tools/sqleditor/tests/test_transaction_status.py b/web/pgadmin/tools/sqleditor/tests/test_transaction_status.py
new file mode 100644
index 000000000..b18dcdc62
--- /dev/null
+++ b/web/pgadmin/tools/sqleditor/tests/test_transaction_status.py
@@ -0,0 +1,253 @@
+##########################################################################
+#
+# pgAdmin 4 - PostgreSQL Tools
+#
+# Copyright (C) 2013 - 2019, The pgAdmin Development Team
+# This software is released under the PostgreSQL Licence
+#
+##########################################################################
+
+import json
+import random
+
+from pgadmin.browser.server_groups.servers.databases.tests import utils as \
+    database_utils
+from pgadmin.utils.route import BaseTestGenerator
+from regression import parent_node_dict
+from regression.python_test_utils import test_utils as utils
+from pgadmin.tools.sqleditor.tests.execute_query_test_utils \
+    import execute_query
+
+from pgadmin.tools.sqleditor.utils.constant_definition \
+    import TX_STATUS_IDLE, TX_STATUS_INTRANS
+
+
+def _build_save_payload(pk_value, normal_value):
+    """Return a /sqleditor/save payload that adds one row to the test table.
+
+    The payload describes the two columns (pk_col, normal_col) of the table
+    created by TestTransactionControl._create_test_table.
+
+    :param pk_value: value (as a string) to insert into pk_col.
+    :param normal_value: value to insert into normal_col.
+    :return: a new dict on every call, so scenarios never share mutable
+        state.
+    """
+    return {
+        "updated": {},
+        "added": {
+            "2": {
+                "err": False,
+                "data": {
+                    "pk_col": pk_value,
+                    "__temp_PK": "2",
+                    "normal_col": normal_value
+                }
+            }
+        },
+        "staged_rows": {},
+        "deleted": {},
+        "updated_index": {},
+        "added_index": {"2": "2"},
+        "columns": [
+            {
+                "name": "pk_col",
+                "display_name": "pk_col",
+                "column_type": "[PK] integer",
+                "column_type_internal": "integer",
+                "pos": 0,
+                "label": "pk_col<br>[PK] integer",
+                "cell": "number",
+                "can_edit": True,
+                "type": "integer",
+                "not_null": True,
+                "has_default_val": False,
+                "is_array": False
+            },
+            {
+                "name": "normal_col",
+                "display_name": "normal_col",
+                "column_type": "character varying",
+                "column_type_internal": "character varying",
+                "pos": 1,
+                "label": "normal_col<br>character varying",
+                "cell": "string",
+                "can_edit": True,
+                "type": "character varying",
+                "not_null": False,
+                "has_default_val": False,
+                "is_array": False
+            }
+        ]
+    }
+
+
+class TestTransactionControl(BaseTestGenerator):
+    """Test the query tool transaction status after various operations.
+
+    Each scenario turns auto-commit on or off, runs a SELECT, saves an added
+    row (which either succeeds or violates the primary key) and checks the
+    status reported by /sqleditor/status after each step. If a transaction
+    is left open it is committed and the status must return to idle.
+    """
+    scenarios = [
+        ('When auto-commit is enabled, and save is successful', dict(
+            is_auto_commit_enabled=True,
+            transaction_status=TX_STATUS_IDLE,
+            save_status=True,
+            save_payload=_build_save_payload("3", "three")
+        )),
+        ('When auto-commit is disabled and save is successful', dict(
+            is_auto_commit_enabled=False,
+            transaction_status=TX_STATUS_INTRANS,
+            save_status=True,
+            save_payload=_build_save_payload("3", "three")
+        )),
+        # pk_col = 1 already exists in the test table, so these saves
+        # violate the primary key constraint.
+        ('When auto-commit is enabled and save fails', dict(
+            is_auto_commit_enabled=True,
+            transaction_status=TX_STATUS_IDLE,
+            save_status=False,
+            save_payload=_build_save_payload("1", "four")
+        )),
+        ('When auto-commit is disabled and save fails', dict(
+            is_auto_commit_enabled=False,
+            transaction_status=TX_STATUS_INTRANS,
+            save_status=False,
+            save_payload=_build_save_payload("1", "four")
+        )),
+    ]
+
+    def setUp(self):
+        self._initialize_database_connection()
+        self._initialize_query_tool()
+        self._initialize_urls()
+
+    def runTest(self):
+        # Saving edited rows relies on updatable resultsets; skip (rather
+        # than fail) on older drivers. tearDown still runs after a skip here.
+        if not self.is_updatable_resultset_supported:
+            self.skipTest('Updatable resultsets require driver version '
+                          '2.8 or later')
+
+        self._create_test_table()
+        self._set_auto_commit(self.is_auto_commit_enabled)
+        self._execute_select_sql()
+        self._check_transaction_status(self.transaction_status)
+        self._save_changed_data()
+        self._check_transaction_status(self.transaction_status)
+
+        # With auto-commit disabled the transaction is still open;
+        # committing it must bring the status back to idle.
+        if self.transaction_status == TX_STATUS_INTRANS:
+            self._commit_transaction()
+            self._check_transaction_status(TX_STATUS_IDLE)
+
+    def tearDown(self):
+        # Disconnect the database
+        database_utils.disconnect_database(self, self.server_id, self.db_id)
+
+    def _set_auto_commit(self, auto_commit):
+        """Turn auto-commit on or off for the query tool transaction."""
+        response = self.tester.post(self.auto_commit_url,
+                                    data=json.dumps(auto_commit),
+                                    content_type='html/json')
+        self.assertEqual(response.status_code, 200)
+
+    def _execute_select_sql(self):
+        """Run the SELECT on the test table through the query tool."""
+        is_success, _ = \
+            execute_query(tester=self.tester,
+                          query=self.select_sql,
+                          start_query_tool_url=self.start_query_tool_url,
+                          poll_url=self.poll_url)
+        self.assertEqual(is_success, True)
+
+    def _check_transaction_status(self, expected_transaction_status):
+        """Assert the transaction status reported by /sqleditor/status."""
+        response = self.tester.get(self.status_url)
+        self.assertEqual(response.status_code, 200)
+        response_data = json.loads(response.data.decode('utf-8'))
+        transaction_status = response_data['data']['status']
+        self.assertEqual(transaction_status, expected_transaction_status)
+
+    def _save_changed_data(self):
+        """Post save_payload and check the reported save outcome."""
+        response = self.tester.post(self.save_url,
+                                    data=json.dumps(self.save_payload),
+                                    content_type='html/json')
+        # The HTTP request returns 200 even when the save itself fails;
+        # the outcome is reported in the response body.
+        self.assertEqual(response.status_code, 200)
+        response_data = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(response_data['data']['status'], self.save_status)
+
+    def _commit_transaction(self):
+        """Commit the open transaction through the query tool."""
+        is_success, _ = \
+            execute_query(tester=self.tester,
+                          query='COMMIT;',
+                          start_query_tool_url=self.start_query_tool_url,
+                          poll_url=self.poll_url)
+        self.assertEqual(is_success, True)
+
+    def _initialize_database_connection(self):
+        database_info = parent_node_dict["database"][-1]
+        self.db_name = database_info["db_name"]
+        self.server_id = database_info["server_id"]
+        self.db_id = database_info["db_id"]
+
+        self.server_version = parent_node_dict["schema"][-1]["server_version"]
+
+        db_con = database_utils.connect_database(self,
+                                                 utils.SERVER_GROUP,
+                                                 self.server_id,
+                                                 self.db_id)
+        if db_con["info"] != "Database connected.":
+            raise Exception("Could not connect to the database.")
+
+        # Compare (major, minor) as integers: a float comparison would read
+        # e.g. "2.10" as 2.1 and wrongly treat it as older than 2.8.
+        driver_version = utils.get_driver_version()
+        major_minor = tuple(int(part)
+                            for part in driver_version.split('.')[:2])
+        self.is_updatable_resultset_supported = major_minor >= (2, 8)
+
+    def _initialize_query_tool(self):
+        url = '/datagrid/initialize/query_tool/{0}/{1}/{2}'.format(
+            utils.SERVER_GROUP, self.server_id, self.db_id)
+        response = self.tester.post(url)
+        self.assertEqual(response.status_code, 200)
+
+        response_data = json.loads(response.data.decode('utf-8'))
+        self.trans_id = response_data['data']['gridTransId']
+
+    def _initialize_urls(self):
+        self.start_query_tool_url = \
+            '/sqleditor/query_tool/start/{0}'.format(self.trans_id)
+        self.save_url = '/sqleditor/save/{0}'.format(self.trans_id)
+        self.poll_url = '/sqleditor/poll/{0}'.format(self.trans_id)
+        self.auto_commit_url = \
+            '/sqleditor/auto_commit/{0}'.format(self.trans_id)
+        self.status_url = '/sqleditor/status/{0}'.format(self.trans_id)
+
+    def _create_test_table(self):
+        test_table_name = "test_for_updatable_resultset" + \
+                          str(random.randint(1000, 9999))
+        create_sql = """
+                            DROP TABLE IF EXISTS "%s";
+
+                            CREATE TABLE "%s"(
+                            pk_col INT PRIMARY KEY,
+                            normal_col VARCHAR);
+
+                            INSERT INTO "%s" VALUES
+                            (1, 'one'),
+                            (2, 'two');
+                      """ % (test_table_name,
+                             test_table_name,
+                             test_table_name)
+
+        self.select_sql = 'SELECT * FROM "%s"' % test_table_name
+        utils.create_table_with_query(self.server, self.db_name, create_sql)
diff --git a/web/pgadmin/tools/sqleditor/utils/tests/test_is_query_resultset_updatable.py b/web/pgadmin/tools/sqleditor/utils/tests/test_is_query_resultset_updatable.py
index 6a3ea38e4..c66f931a6 100644
--- a/web/pgadmin/tools/sqleditor/utils/tests/test_is_query_resultset_updatable.py
+++ b/web/pgadmin/tools/sqleditor/utils/tests/test_is_query_resultset_updatable.py
@@ -15,7 +15,7 @@ from pgadmin.browser.server_groups.servers.databases.tests import utils as \
 from pgadmin.utils.route import BaseTestGenerator
 from regression import parent_node_dict
 from regression.python_test_utils import test_utils as utils
-from pgadmin.tools.sqleditor.tests.execute_query_utils import execute_query
+from pgadmin.tools.sqleditor.tests.execute_query_test_utils import execute_query
 
 
 class TestQueryUpdatableResultset(BaseTestGenerator):
@@ -94,30 +94,33 @@ class TestQueryUpdatableResultset(BaseTestGenerator):
         self._initialize_urls()
 
     def runTest(self):
-        # Create test table (unique for each scenario)
-        test_table_name = self._create_test_table(
-            table_has_oids=self.table_has_oids)
-        # Add test table name to the query
-        sql = self.sql % test_table_name
+        self._create_test_table(table_has_oids=self.table_has_oids)
+        response_data = self._execute_select_sql()
+        self._check_primary_keys(response_data)
+        self._check_oids(response_data)
+
+    def tearDown(self):
+        # Disconnect the database
+        database_utils.disconnect_database(self, self.server_id, self.db_id)
+
+    def _execute_select_sql(self):
+        sql = self.sql % self.test_table_name
         is_success, response_data = \
             execute_query(tester=self.tester,
                           query=sql,
                           poll_url=self.poll_url,
                           start_query_tool_url=self.start_query_tool_url)
         self.assertEquals(is_success, True)
+        return response_data
 
-        # Check primary keys
+    def _check_primary_keys(self, response_data):
         primary_keys = response_data['data']['primary_keys']
         self.assertEquals(primary_keys, self.primary_keys)
 
-        # Check oids
+    def _check_oids(self, response_data):
         has_oids = response_data['data']['has_oids']
         self.assertEquals(has_oids, self.expected_has_oids)
 
-    def tearDown(self):
-        # Disconnect the database
-        database_utils.disconnect_database(self, self.server_id, self.db_id)
-
     def _initialize_database_connection(self):
         database_info = parent_node_dict["database"][-1]
         self.db_name = database_info["db_name"]
@@ -160,8 +163,8 @@ class TestQueryUpdatableResultset(BaseTestGenerator):
         self.poll_url = '/sqleditor/poll/{0}'.format(self.trans_id)
 
     def _create_test_table(self, table_has_oids=False):
-        test_table_name = "test_for_updatable_resultset" + \
-                          str(random.randint(1000, 9999))
+        self.test_table_name = "test_for_updatable_resultset" + \
+                               str(random.randint(1000, 9999))
         create_sql = """
                             DROP TABLE IF EXISTS "%s";
 
@@ -172,7 +175,7 @@ class TestQueryUpdatableResultset(BaseTestGenerator):
                                 normal_col2 VARCHAR,
                                 PRIMARY KEY(pk_col1, pk_col2)
                             )
-                      """ % (test_table_name, test_table_name)
+                      """ % (self.test_table_name, self.test_table_name)
 
         if table_has_oids:
             create_sql += ' WITH OIDS;'
@@ -180,4 +183,3 @@ class TestQueryUpdatableResultset(BaseTestGenerator):
             create_sql += ';'
 
         utils.create_table_with_query(self.server, self.db_name, create_sql)
-        return test_table_name
diff --git a/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_data.py b/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_data.py
index ae0fdc49d..1134329c9 100644
--- a/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_data.py
+++ b/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_data.py
@@ -15,7 +15,7 @@ from pgadmin.browser.server_groups.servers.databases.tests import utils as \
 from pgadmin.utils.route import BaseTestGenerator
 from regression import parent_node_dict
 from regression.python_test_utils import test_utils as utils
-from pgadmin.tools.sqleditor.tests.execute_query_utils import execute_query
+from pgadmin.tools.sqleditor.tests.execute_query_test_utils import execute_query
 
 
 class TestSaveChangedData(BaseTestGenerator):
@@ -263,16 +263,25 @@ class TestSaveChangedData(BaseTestGenerator):
         self._initialize_urls_and_select_sql()
 
     def runTest(self):
-        # Create test table (unique for each scenario)
         self._create_test_table()
-        # Execute select sql
-        is_success, _ = \
+        self._execute_sql_query(self.select_sql)
+        self._save_changed_data()
+        self._check_saved_data()
+
+    def tearDown(self):
+        # Disconnect the database
+        database_utils.disconnect_database(self, self.server_id, self.db_id)
+
+    def _execute_sql_query(self, query):
+        is_success, response_data = \
             execute_query(tester=self.tester,
-                          query=self.select_sql,
+                          query=query,
                           start_query_tool_url=self.start_query_tool_url,
                           poll_url=self.poll_url)
         self.assertEquals(is_success, True)
+        return response_data
 
+    def _save_changed_data(self):
         # Send a request to save changed data
         response = self.tester.post(self.save_url,
                                     data=json.dumps(self.save_payload),
@@ -285,24 +294,13 @@ class TestSaveChangedData(BaseTestGenerator):
         save_status = response_data['data']['status']
         self.assertEquals(save_status, self.save_status)
 
-        # Execute check sql
-        # Add test table name to the query
+    def _check_saved_data(self):
         check_sql = self.check_sql % self.test_table_name
-        is_success, response_data = \
-            execute_query(tester=self.tester,
-                          query=check_sql,
-                          start_query_tool_url=self.start_query_tool_url,
-                          poll_url=self.poll_url)
-        self.assertEquals(is_success, True)
-
+        response_data = self._execute_sql_query(check_sql)
         # Check table for updates
         result = response_data['data']['result']
         self.assertEquals(result, self.check_result)
 
-    def tearDown(self):
-        # Disconnect the database
-        database_utils.disconnect_database(self, self.server_id, self.db_id)
-
     def _initialize_database_connection(self):
         database_info = parent_node_dict["database"][-1]
         self.db_name = database_info["db_name"]
@@ -333,7 +331,6 @@ class TestSaveChangedData(BaseTestGenerator):
         self.trans_id = response_data['data']['gridTransId']
 
     def _initialize_urls_and_select_sql(self):
-
         self.start_query_tool_url = \
             '/sqleditor/query_tool/start/{0}'.format(self.trans_id)
         self.save_url = '/sqleditor/save/{0}'.format(self.trans_id)

Reply via email to