This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 4b87a71ba1f8b6036f172fbda573a5626f1c8482
Author: Orhan Kislal <[email protected]>
AuthorDate: Thu Feb 25 20:30:43 2021 +0300

    DL: Check if the owner of the object table is a superuser
---
 .../modules/deep_learning/madlib_keras_custom_function.py_in  |  8 +++-----
 .../modules/deep_learning/madlib_keras_validator.py_in        |  8 ++++++++
 src/ports/postgres/modules/utilities/utilities.py_in          | 11 +++++++++++
 3 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index 1ebf9f6..32a5757 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -128,13 +128,11 @@ def load_custom_function(schema_madlib, object_table, object, name, description=
 def delete_custom_function(schema_madlib, object_table, id=None, name=None, 
**kwargs):
 
     if object_table is not None:
-        schema_name = get_schema(object_table)
-        if schema_name is None:
-            object_table = "{0}.{1}".format(schema_madlib, 
quote_ident(object_table))
-        elif schema_name != schema_madlib:
-            plpy.error("DL: Custom function table has to be in the {0} 
schema".format(schema_madlib))
+        object_table = "{0}.{1}".format(schema_madlib, 
quote_ident(object_table))
 
     input_tbl_valid(object_table, "Keras Custom Funtion")
+    _assert(is_superuser(current_user()), "DL: The user has to have admin "\
+        "privilages to delete a custom function")
     _assert(id is not None or name is not None,
             "{0}: function id/name cannot be NULL! " \
             "Use \"SELECT delete_custom_function('usage')\" for 
help.".format(module_name))
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 535d70d..ab8d336 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -40,6 +40,8 @@ from utilities.utilities import _assert
 from utilities.utilities import add_postfix
 from utilities.utilities import is_platform_pg
 from utilities.utilities import is_var_valid
+from utilities.utilities import is_superuser
+from utilities.utilities import get_table_owner
 from utilities.validate_args import cols_in_tbl_valid
 from utilities.validate_args import columns_exist_in_table
 from utilities.validate_args import get_expr_type
@@ -324,6 +326,9 @@ class FitCommonValidator(object):
 
         if self.object_table is not None:
             input_tbl_valid(self.object_table, self.module_name)
+
+            _assert(is_superuser(get_table_owner(self.object_table)),
+                "DL: Cannot use a table of a non-superuser as object table.")
             cols_in_tbl_valid(self.object_table, 
CustomFunctionSchema.col_names, self.module_name)
 
         if self.warm_start:
@@ -543,6 +548,7 @@ class MstLoaderInputValidator():
         # Default metrics, since it is not part of the builtin metrics list
         builtin_metrics.append('accuracy')
         if self.object_table is not None:
+
             res = plpy.execute("SELECT {0} from 
{1}".format(CustomFunctionSchema.FN_NAME,
                                                             self.object_table))
             for r in res:
@@ -576,6 +582,8 @@ class MstLoaderInputValidator():
         input_tbl_valid(self.model_arch_table, self.module_name)
         if self.object_table is not None:
             input_tbl_valid(self.object_table, self.module_name)
+            _assert(is_superuser(get_table_owner(self.object_table)),
+                "DL: Cannot use a table of a non-superuser as object table.")
         if self.module_name == 'load_model_selection_table' or 
self.module_name == 'madlib_keras_automl':
             output_tbl_valid(self.model_selection_table, self.module_name)
             output_tbl_valid(self.model_selection_summary_table, 
self.module_name)
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index e5a4c3d..8fc4a28 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -775,6 +775,17 @@ def is_superuser(user):
     return plpy.execute("SELECT rolsuper FROM pg_catalog.pg_roles "\
                         "WHERE rolname = '{0}'".format(user))[0]['rolsuper']
 
+def get_table_owner(schema_table):
+
+    split_table = schema_table.split(".",1)
+    schema = split_table[0]
+    non_schema_table = split_table[1]
+
+    q = """SELECT tableowner FROM pg_catalog.pg_tables
+           WHERE schemaname='{0}' AND tablename='{1}'
+           """.format(schema, non_schema_table)
+    return plpy.execute(q)[0]['tableowner']
+
 def madlib_version(schema_madlib):
     """Returns the MADlib version string."""
     raw = plpy.execute("""

Reply via email to