This is an automated email from the ASF dual-hosted git repository. khannaekta pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 4e4a143af9537a8560716567061c75ffb4bbdd5c Author: Ekta Khanna <[email protected]> AuthorDate: Mon Mar 9 10:58:12 2020 -0700 Control: Add a new wrapper for setting GUC values This commit adds a new wrapper that sets/unsets the specified GUC to the value passed in. The GUC is reset to its original value on exit. Co-authored-by: Nikhil Kak <[email protected]> --- src/ports/postgres/modules/utilities/control.py_in | 53 ++++++++++++++++++++++ .../utilities/test/unit_tests/test_control.py_in | 30 ++++++++++++ 2 files changed, 83 insertions(+) diff --git a/src/ports/postgres/modules/utilities/control.py_in b/src/ports/postgres/modules/utilities/control.py_in index 13b22f0..b52a881 100644 --- a/src/ports/postgres/modules/utilities/control.py_in +++ b/src/ports/postgres/modules/utilities/control.py_in @@ -44,6 +44,59 @@ class ContextDecorator(object): return wrapper +class SetGUC(ContextDecorator): """ @brief: A wrapper that sets/unsets GUCs and then sets it back to the original value on exit This context manager sets the specified GUC to the value passed in """ def __init__(self, guc_name, new_guc_value, error_on_fail=True): self.guc_name = guc_name self.new_guc_value = new_guc_value if not self.guc_name or not self.new_guc_value: plpy.error("Both guc_name and new_guc_value need to have a non null" "value") self.error_on_fail = error_on_fail self.guc_exists = True self.old_value = None def __enter__(self): if self.guc_exists: # check if allowed to change the GUC try: show_query = "show {0}".format(self.guc_name) self.old_value = plpy.execute(show_query)[0] self.old_value = self.old_value["{0}".format(self.guc_name)] except plpy.SPIError: self.guc_exists = False return self if self.new_guc_value: plpy.execute("set {0}={1}".format(self.guc_name, self.new_guc_value)) else: if self.error_on_fail: plpy.error("Cannot set {0} to None. 
Please provide a valid value" + .format(self.guc_name)) + plpy.error("Unable to change '{0}' value. " + "Set '{0} = \'{1}\'' to proceed.". + format(self.guc_name, self.guc_value)) + return self + + def __exit__(self, *args): + if args and args[0]: + # an exception was raised in code, return False so that any + # exception is re-raised after exit. The transaction will not + # commit leading to reset of any change to parameter. + return False + else: + if self.guc_exists and self.old_value: + pass + plpy.execute("set {0}='{1}'". + format(self.guc_name, self.old_value)) + + class OptimizerControl(ContextDecorator): """ @brief: A wrapper that enables/disables the optimizer and diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_control.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_control.py_in index 66a429e..2d43968 100644 --- a/src/ports/postgres/modules/utilities/test/unit_tests/test_control.py_in +++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_control.py_in @@ -76,6 +76,36 @@ class ControlTestCase(unittest.TestCase): with self.subject.AOControl(True) as C: self.assertFalse(C.guc_exists) +class SetGUCTestCase(unittest.TestCase): + def setUp(self): + patches = { + 'plpy': plpy + } + self.plpy_mock_execute = MagicMock() + plpy.execute = self.plpy_mock_execute + + self.module_patcher = patch.dict('sys.modules', patches) + self.module_patcher.start() + + import control + self.subject = control + + def tearDown(self): + self.module_patcher.stop() + + def test_set_guc_sets_new_value(self): + self.plpy_mock_execute.return_value = [{'foo': 'new_bar'}] + with self.subject.SetGUC("foo", "new_bar") as C: + self.assertTrue("new_bar", C.new_guc_value) + self.plpy_mock_execute.assert_called_with( + "set foo='new_bar'") + + def test_set_guc_missing(self): + self.plpy_mock_execute.side_effect = plpy.SPIError( + 'Unrecognized configuration parameter "foo"') + with self.subject.SetGUC("foo", "new_bar") as C: + 
self.assertFalse(C.guc_exists) + if __name__ == '__main__': unittest.main()
