This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 4e4a143af9537a8560716567061c75ffb4bbdd5c
Author: Ekta Khanna <[email protected]>
AuthorDate: Mon Mar 9 10:58:12 2020 -0700

    Control: Add a new wrapper for setting GUC values
    
    This commit adds a new wrapper that sets the specified GUC to the
    value passed in. The GUC is reset back to its original value on exit.
    
    Co-authored-by: Nikhil Kak <[email protected]>
---
 src/ports/postgres/modules/utilities/control.py_in | 53 ++++++++++++++++++++++
 .../utilities/test/unit_tests/test_control.py_in   | 30 ++++++++++++
 2 files changed, 83 insertions(+)

diff --git a/src/ports/postgres/modules/utilities/control.py_in b/src/ports/postgres/modules/utilities/control.py_in
index 13b22f0..b52a881 100644
--- a/src/ports/postgres/modules/utilities/control.py_in
+++ b/src/ports/postgres/modules/utilities/control.py_in
@@ -44,6 +44,59 @@ class ContextDecorator(object):
         return wrapper
 
 
+class SetGUC(ContextDecorator):
+    """
+    @brief: A wrapper that sets a GUC to the value passed in and then
+        sets it back to the original value on exit
+
+    @param guc_name: name of the GUC to change
+    @param new_guc_value: value to set; None and '' are rejected
+    @param error_on_fail: report an unusable value through plpy.error
+
+    If the GUC does not exist, guc_exists is set to False and nothing
+    is changed on enter or exit.
+    """
+
+    def __init__(self, guc_name, new_guc_value, error_on_fail=True):
+        self.guc_name = guc_name
+        self.new_guc_value = new_guc_value
+        # 0 and False are legitimate GUC values; only reject None and ''
+        if not self.guc_name or self.new_guc_value in (None, ''):
+            plpy.error("Both guc_name and new_guc_value need to have a "
+                       "non null value")
+        self.error_on_fail = error_on_fail
+        self.guc_exists = True
+        self.old_value = None
+
+    def __enter__(self):
+        if self.guc_exists:
+            # remember the current value; 'show' fails for an unknown GUC
+            try:
+                show_query = "show {0}".format(self.guc_name)
+                self.old_value = plpy.execute(show_query)[0][self.guc_name]
+            except plpy.SPIError:
+                self.guc_exists = False
+                return self
+
+            if self.new_guc_value not in (None, ''):
+                plpy.execute("set {0}={1}".format(self.guc_name,
+                                                  self.new_guc_value))
+            elif self.error_on_fail:
+                plpy.error("Cannot set {0} to None. Please provide a valid "
+                           "value".format(self.guc_name))
+        return self
+
+    def __exit__(self, *args):
+        if args and args[0]:
+            # an exception was raised in code, return False so that any
+            # exception is re-raised after exit. The transaction will not
+            # commit leading to reset of any change to parameter.
+            return False
+        # an empty string can be a valid setting, so test against None
+        if self.guc_exists and self.old_value is not None:
+            plpy.execute("set {0}='{1}'".
+                         format(self.guc_name, self.old_value))
+
 class OptimizerControl(ContextDecorator):
     """
     @brief: A wrapper that enables/disables the optimizer and
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_control.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_control.py_in
index 66a429e..2d43968 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_control.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_control.py_in
@@ -76,6 +76,36 @@ class ControlTestCase(unittest.TestCase):
         with self.subject.AOControl(True) as C:
             self.assertFalse(C.guc_exists)
 
+class SetGUCTestCase(unittest.TestCase):
+    """Unit tests for control.SetGUC with plpy.execute mocked out"""
+    def setUp(self):
+        patches = {'plpy': plpy}
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+
+        import control
+        self.subject = control
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
+    def test_set_guc_sets_new_value(self):
+        # 'show foo' reports the old value, which must be restored on exit
+        self.plpy_mock_execute.return_value = [{'foo': 'old_bar'}]
+        with self.subject.SetGUC("foo", "new_bar") as C:
+            self.assertEqual("new_bar", C.new_guc_value)
+            self.plpy_mock_execute.assert_called_with("set foo=new_bar")
+        self.plpy_mock_execute.assert_called_with("set foo='old_bar'")
+
+    def test_set_guc_missing(self):
+        self.plpy_mock_execute.side_effect = plpy.SPIError(
+            'Unrecognized configuration parameter "foo"')
+        with self.subject.SetGUC("foo", "new_bar") as C:
+            self.assertFalse(C.guc_exists)
+
 
 if __name__ == '__main__':
     unittest.main()

Reply via email to