This is an automated email from the ASF dual-hosted git repository. okislal pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 4b87a71ba1f8b6036f172fbda573a5626f1c8482 Author: Orhan Kislal <[email protected]> AuthorDate: Thu Feb 25 20:30:43 2021 +0300 DL: Check if the owner of the object table is a superuser --- .../modules/deep_learning/madlib_keras_custom_function.py_in | 8 +++----- .../modules/deep_learning/madlib_keras_validator.py_in | 8 ++++++++ src/ports/postgres/modules/utilities/utilities.py_in | 11 +++++++++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in index 1ebf9f6..32a5757 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in @@ -128,13 +128,11 @@ def load_custom_function(schema_madlib, object_table, object, name, description= def delete_custom_function(schema_madlib, object_table, id=None, name=None, **kwargs): if object_table is not None: - schema_name = get_schema(object_table) - if schema_name is None: - object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table)) - elif schema_name != schema_madlib: - plpy.error("DL: Custom function table has to be in the {0} schema".format(schema_madlib)) + object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table)) input_tbl_valid(object_table, "Keras Custom Funtion") + _assert(is_superuser(current_user()), "DL: The user has to have admin "\ + "privilages to delete a custom function") _assert(id is not None or name is not None, "{0}: function id/name cannot be NULL! " \ "Use \"SELECT delete_custom_function('usage')\" for help.".format(module_name)) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in index 535d70d..ab8d336 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in @@ -40,6 +40,8 @@ from utilities.utilities import _assert from utilities.utilities import add_postfix from utilities.utilities import is_platform_pg from utilities.utilities import is_var_valid +from utilities.utilities import is_superuser +from utilities.utilities import get_table_owner from utilities.validate_args import cols_in_tbl_valid from utilities.validate_args import columns_exist_in_table from utilities.validate_args import get_expr_type @@ -324,6 +326,9 @@ class FitCommonValidator(object): if self.object_table is not None: input_tbl_valid(self.object_table, self.module_name) + + _assert(is_superuser(get_table_owner(self.object_table)), + "DL: Cannot use a table of a non-superuser as object table.") cols_in_tbl_valid(self.object_table, CustomFunctionSchema.col_names, self.module_name) if self.warm_start: @@ -543,6 +548,7 @@ class MstLoaderInputValidator(): # Default metrics, since it is not part of the builtin metrics list builtin_metrics.append('accuracy') if self.object_table is not None: + res = plpy.execute("SELECT {0} from {1}".format(CustomFunctionSchema.FN_NAME, self.object_table)) for r in res: @@ -576,6 +582,8 @@ class MstLoaderInputValidator(): input_tbl_valid(self.model_arch_table, self.module_name) if self.object_table is not None: input_tbl_valid(self.object_table, self.module_name) + _assert(is_superuser(get_table_owner(self.object_table)), + "DL: Cannot use a table of a non-superuser as object table.") if self.module_name == 'load_model_selection_table' or self.module_name == 'madlib_keras_automl': output_tbl_valid(self.model_selection_table, self.module_name) output_tbl_valid(self.model_selection_summary_table, self.module_name) diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index e5a4c3d..8fc4a28 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -775,6 +775,17 @@ def is_superuser(user): return plpy.execute("SELECT rolsuper FROM pg_catalog.pg_roles "\ "WHERE rolname = '{0}'".format(user))[0]['rolsuper'] +def get_table_owner(schema_table): + + split_table = schema_table.split(".",1) + schema = split_table[0] + non_schema_table = split_table[1] + + q = """SELECT tableowner FROM pg_catalog.pg_tables + WHERE schemaname='{0}' AND tablename='{1}' + """.format(schema, non_schema_table) + return plpy.execute(q)[0]['tableowner'] + def madlib_version(schema_madlib): """Returns the MADlib version string.""" raw = plpy.execute("""
