Repository: madlib Updated Branches: refs/heads/master 32cce1a16 -> cefd15eac
http://git-wip-us.apache.org/repos/asf/madlib/blob/cefd15ea/src/madpack/upgrade_util.py ---------------------------------------------------------------------- diff --git a/src/madpack/upgrade_util.py b/src/madpack/upgrade_util.py index c521f2e..53051fe 100644 --- a/src/madpack/upgrade_util.py +++ b/src/madpack/upgrade_util.py @@ -1,17 +1,23 @@ +from collections import defaultdict, Iterable +import glob +import os import re import yaml -from collections import defaultdict -import os from utilities import is_rev_gte from utilities import get_rev_num +from utilities import run_query +from utilities import get_dbver -def run_sql(sql, portid, con_args): - """ - @brief Wrapper function for run_query - """ - from madpack import run_query - return run_query(sql, True, con_args) +if not __name__ == "__main__": + def run_sql(sql, portid, con_args): + """ + @brief Wrapper function for run_query + """ + return run_query(sql, con_args, True) +else: + def run_sql(sql, portid, con_args): + return [{'dummy': 0}] def get_signature_for_compare(schema, proname, rettype, argument): @@ -28,26 +34,34 @@ class UpgradeBase: """ @brief Base class for handling the upgrade """ + def __init__(self, schema, portid, con_args): self._schema = schema.lower() self._portid = portid self._con_args = con_args self._schema_oid = None self._get_schema_oid() + self._dbver = get_dbver(self._con_args, self._portid) """ @brief Wrapper function for run_sql """ + def _run_sql(self, sql): return run_sql(sql, self._portid, self._con_args) """ @brief Get the oids of some objects from the catalog in the current version """ + def _get_schema_oid(self): - self._schema_oid = self._run_sql(""" - SELECT oid FROM pg_namespace WHERE nspname = '{schema}' - """.format(schema=self._schema))[0]['oid'] + res = self._run_sql("SELECT oid FROM pg_namespace WHERE nspname = '{0}'". 
+ format(self._schema))[0] + if 'oid' in res: + self._schema_oid = res['oid'] + else: + self._schema_oid = None + return self._schema_oid def _get_function_info(self, oid): """ @@ -92,12 +106,19 @@ class ChangeHandler(UpgradeBase): @brief This class reads changes from the configuration file and handles the dropping of objects """ - def __init__(self, schema, portid, con_args, maddir, mad_dbrev, is_hawq2): + + def __init__(self, schema, portid, con_args, maddir, mad_dbrev, + is_hawq2, upgrade_to=None): UpgradeBase.__init__(self, schema, portid, con_args) + + # FIXME: maddir includes the '/src' folder. It's supposed to be the + # parent of that directory. self._maddir = maddir self._mad_dbrev = mad_dbrev self._is_hawq2 = is_hawq2 self._newmodule = {} + self._curr_rev = self._get_current_version() if not upgrade_to else upgrade_to + self._udt = {} self._udf = {} self._uda = {} @@ -106,7 +127,19 @@ class ChangeHandler(UpgradeBase): self._udoc = {} self._load() - def _load_config_param(self, config_iterable): + def _get_current_version(self): + """ Get current version of MADlib + + This currently assumes that version is available in + '$MADLIB_HOME/src/config/Version.yml' + """ + version_filepath = os.path.abspath( + os.path.join(self._maddir, 'config', 'Version.yml')) + with open(version_filepath) as ver_file: + version_str = str(yaml.load(ver_file)['version']) + return get_rev_num(version_str) + + def _load_config_param(self, config_iterable, output_config_dict=None): """ Replace schema_madlib with the appropriate schema name and make all function names lower case to ensure ease of comparison. @@ -114,20 +147,18 @@ class ChangeHandler(UpgradeBase): Args: @param config_iterable is an iterable of dictionaries, each with key = object name (eg. function name) and value = details - for the object. The details for the object are assumed to - be in a dictionary with following keys: + for the object. 
The details for the object are assumed + to be in a dictionary with following keys: rettype: Return type argument: List of arguments Returns: A dictionary that lists all specific objects (functions, aggregates, etc) with object name as key and a list as value, where the list - contains all the items present in - - another dictionary with objects details - as the value. + contains all the items present in another dictionary with objects + details as the value. """ - _return_obj = defaultdict(list) + _return_obj = defaultdict(list) if not output_config_dict else output_config_dict if config_iterable is not None: for each_config in config_iterable: for obj_name, obj_details in each_config.iteritems(): @@ -138,38 +169,105 @@ class ChangeHandler(UpgradeBase): _return_obj[obj_name].append(formatted_obj) return _return_obj + @classmethod + def _add_to_dict(cls, src_dict, dest_dict): + """ Update dictionary with contents of another dictionary + + This function performs the same function as dict.update except it adds + to an existing value (instead of replacing it) if the value is an + Iterable. + """ + if src_dict: + for k, v in src_dict.items(): + if k in dest_dict: + if (isinstance(dest_dict[k], Iterable) and isinstance(v, Iterable)): + dest_dict[k] += v + elif isinstance(dest_dict[k], Iterable): + dest_dict[k].append(v) + else: + dest_dict[k] = v + else: + dest_dict[k] = v + return dest_dict + + def _update_objects(self, config): + """ Update each upgrade object """ + self._add_to_dict(config['new module'], self._newmodule) + self._add_to_dict(config['udt'], self._udt) + self._add_to_dict(config['udc'], self._udc) + self._add_to_dict(self._load_config_param(config['udf']), self._udf) + self._add_to_dict(self._load_config_param(config['uda']), self._uda) + + def _get_relevant_filenames(self, upgrade_from): + """ Get all changelist files that together describe the upgrade process + + Args: + @param upgrade_from: List. 
Version to upgrade from - the format is + expected to be per the output of get_rev_num + + Details: + Changelist files are named in the format changelist_<src>_<dest>.yaml + + When upgrading from 'upgrade_from_rev' to 'self._curr_rev', all + intermediate changelist files need to be followed to get all upgrade + steps. This function globs for such files and filters in changelists + that lie between the desired versions. + + Additional verification: The function also ensures that a valid + upgrade path exists. Each version in the changelist files needs to + be seen twice (except upgrade_from and upgrade_to) for a valid path. + This is verified by performing an xor-like operation by + adding/deleting from a list. + """ + output_filenames = [] + upgrade_to = self._curr_rev + + verify_list = [upgrade_from, upgrade_to] + + # assuming that changelists are in the same directory as this file + glob_filter = os.path.abspath( + os.path.join(self._maddir, 'madpack', 'changelist*.yaml')) + all_changelists = glob.glob(glob_filter) + for each_ch in all_changelists: + # split file names to get dest versions + # Assumption: changelist format is + # changelist_<src>_<dest>.yaml + ch_basename = os.path.splitext(os.path.basename(each_ch))[0] # remove extension + ch_splits = ch_basename.split('_') # underscore delineates sections + if len(ch_splits) >= 3: + src_version, dest_version = [get_rev_num(i) for i in ch_splits[1:3]] + + # file is part of upgrade if + # upgrade_to >= dest >= src >= upgrade_from + is_part_of_upgrade = ( + is_rev_gte(src_version, upgrade_from) and + is_rev_gte(upgrade_to, dest_version)) + if is_part_of_upgrade: + for ver in (src_version, dest_version): + if ver in verify_list: + verify_list.remove(ver) + else: + verify_list.append(ver) + abs_path = os.path.join(self._maddir, 'src', 'madpack', each_ch) + output_filenames.append(abs_path) + + if verify_list: + # any version remaining in verify_list implies upgrade path is broken + raise RuntimeError("Upgrade from 
{0} to {1} broken due to missing " + "changelist files ({2}). ". + format(upgrade_from, upgrade_to, verify_list)) + return output_filenames + def _load(self): """ @brief Load the configuration file """ - rev = get_rev_num(self._mad_dbrev) - - # _mad_dbrev = 1.9.1 - if is_rev_gte([1,9,1],rev): - filename = os.path.join(self._maddir, 'madpack', - 'changelist_1.9.1_1.12.yaml') - # _mad_dbrev = 1.10.0 - elif is_rev_gte([1,10],rev): - filename = os.path.join(self._maddir, 'madpack', - 'changelist_1.10.0_1.12.yaml') - # _mad_dbrev = 1.11 - else: - filename = os.path.join(self._maddir, 'madpack', - 'changelist.yaml') - - config = yaml.load(open(filename)) - - self._newmodule = config['new module'] if config['new module'] else {} - self._udt = config['udt'] if config['udt'] else {} - self._udc = config['udc'] if config['udc'] else {} - self._udf = self._load_config_param(config['udf']) - self._uda = self._load_config_param(config['uda']) - # FIXME remove the following special handling for HAWQ after svec is - # removed from catalog - if self._portid != 'hawq' and not self._is_hawq2: - self._udo = self._load_config_param(config['udo']) - self._udoc = self._load_config_param(config['udoc']) + upgrade_filenames = self._get_relevant_filenames(get_rev_num(self._mad_dbrev)) + for f in upgrade_filenames: + with open(f) as handle: + config = yaml.load(handle) + self._update_objects(config) @property def newmodule(self): @@ -259,8 +357,9 @@ class ChangeHandler(UpgradeBase): for opc, li in self._udoc.items(): for e in li: changed_opcs.add((opc, e['index'])) - - if self._portid == 'postgres': + gte_gpdb5 = (self._portid == 'greenplum' and + is_rev_gte(get_rev_num(self._dbver), get_rev_num('5.0'))) + if (self._portid == 'postgres' or gte_gpdb5): method_col = 'opcmethod' else: method_col = 'opcamid' @@ -339,8 +438,8 @@ class ChangeHandler(UpgradeBase): """ for op in self._udo: for value in self._udo[op]: - leftarg=value['leftarg'].replace('schema_madlib', self._schema) - 
rightarg=value['rightarg'].replace('schema_madlib', self._schema) + leftarg = value['leftarg'].replace('schema_madlib', self._schema) + rightarg = value['rightarg'].replace('schema_madlib', self._schema) self._run_sql(""" DROP OPERATOR IF EXISTS {schema}.{op} ({leftarg}, {rightarg}) """.format(schema=self._schema, **locals())) @@ -356,11 +455,13 @@ class ChangeHandler(UpgradeBase): DROP OPERATOR CLASS IF EXISTS {schema}.{op_cls} USING {index} """.format(schema=self._schema, **locals())) + class ViewDependency(UpgradeBase): """ @brief This class detects the direct/recursive view dependencies on MADLib UDFs/UDAs/UDOs defined in the current version """ + def __init__(self, schema, portid, con_args): UpgradeBase.__init__(self, schema, portid, con_args) self._view2proc = None @@ -452,6 +553,7 @@ class ViewDependency(UpgradeBase): """ @brief Detect recursive view dependencies (view on view) """ + def _detect_recursive_view_dependency(self): rows = self._run_sql(""" SELECT @@ -499,9 +601,9 @@ class ViewDependency(UpgradeBase): @brief Filter out recursive view dependencies which are independent of MADLib UDFs/UDAs """ + def _filter_recursive_view_dependency(self): # Get initial list - import sys checklist = [] checklist.extend(self._view2proc.keys()) checklist.extend(self._view2op.keys()) @@ -530,6 +632,7 @@ class ViewDependency(UpgradeBase): """ @brief Build the dependency graph (depender-to-dependee adjacency list) """ + def _build_dependency_graph(self, hasProcDependency=False): der2dee = self._view2view.copy() for view in self._view2proc: @@ -554,12 +657,14 @@ class ViewDependency(UpgradeBase): """ @brief Check dependencies """ + def has_dependency(self): return (len(self._view2proc) > 0) or (len(self._view2op) > 0) """ @brief Get the ordered views for creation """ + def get_create_order_views(self): graph = self._build_dependency_graph() ordered_views = [] @@ -581,6 +686,7 @@ class ViewDependency(UpgradeBase): """ @brief Get the ordered views for dropping """ + def 
get_drop_order_views(self): ordered_views = self.get_create_order_views() ordered_views.reverse() @@ -678,10 +784,8 @@ class ViewDependency(UpgradeBase): SET ROLE {owner}; CREATE OR REPLACE VIEW {schema}.{view} AS {definition}; RESET ROLE - """.format( - schema=schema, view=view, - definition=definition, - owner=owner)) + """.format(schema=schema, view=view, + definition=definition, owner=owner)) def _node_to_str(self, node): if len(node) == 2: @@ -717,6 +821,7 @@ class TableDependency(UpgradeBase): @brief This class detects the table dependencies on MADLib UDTs defined in the current version """ + def __init__(self, schema, portid, con_args): UpgradeBase.__init__(self, schema, portid, con_args) self._table2type = None @@ -836,6 +941,7 @@ class ScriptCleaner(UpgradeBase): @brief This class removes sql statements from a sql script which should not be executed during the upgrade """ + def __init__(self, schema, portid, con_args, change_handler): UpgradeBase.__init__(self, schema, portid, con_args) self._ch = change_handler @@ -853,7 +959,9 @@ class ScriptCleaner(UpgradeBase): """ @brief Get the existing UDOCs in the current version """ - if self._portid == 'postgres': + gte_gpdb5 = (self._portid == 'greenplum' and + is_rev_gte(get_rev_num(self._dbver), get_rev_num('5.0'))) + if (self._portid == 'postgres' or gte_gpdb5): method_col = 'opcmethod' else: method_col = 'opcamid' @@ -887,8 +995,8 @@ class ScriptCleaner(UpgradeBase): self._existing_udo = defaultdict(list) for row in rows: self._existing_udo[row['oprname']].append( - {'leftarg': row['oprleft'], - 'rightarg': row['oprright']}) + {'leftarg': row['oprleft'], + 'rightarg': row['oprright']}) def _get_existing_uda(self): """ @@ -928,8 +1036,8 @@ class ScriptCleaner(UpgradeBase): for row in rows: # Consider about the overloaded aggregates self._existing_uda[row['proname']].append( - {'rettype': row['rettype'], - 'argument': row['argument']}) + {'rettype': row['rettype'], + 'argument': row['argument']}) def 
_get_unchanged_operator_patterns(self): """ @@ -938,7 +1046,7 @@ class ScriptCleaner(UpgradeBase): @return unchanged = existing - changed """ - self._get_existing_udo() # from the old version + self._get_existing_udo() # from the old version operator_patterns = [] # for all, pass the changed ones, add others to ret for each_udo, udo_details in self._existing_udo.items(): @@ -965,7 +1073,7 @@ class ScriptCleaner(UpgradeBase): @return unchanged = existing - changed """ - self._get_existing_udoc() # from the old version + self._get_existing_udoc() # from the old version opclass_patterns = [] # for all, pass the changed ones, add others to ret for each_udoc, udoc_details in self._existing_udoc.items(): @@ -1055,6 +1163,7 @@ class ScriptCleaner(UpgradeBase): """ @breif Remove "drop/create type" statements in the sql script """ + def _clean_type(self): # remove 'drop type' pattern = re.compile('DROP(\s+)TYPE(.*?);', re.DOTALL | re.IGNORECASE) @@ -1076,6 +1185,7 @@ class ScriptCleaner(UpgradeBase): """ @brief Remove "drop/create cast" statements in the sql script """ + def _clean_cast(self): # remove 'drop cast' pattern = re.compile('DROP(\s+)CAST(.*?);', re.DOTALL | re.IGNORECASE) @@ -1102,6 +1212,7 @@ class ScriptCleaner(UpgradeBase): """ @brief Remove "drop/create operator" statements in the sql script """ + def _clean_operator(self): # remove 'drop operator' pattern = re.compile('DROP\s+OPERATOR.*?PROCEDURE\s+=.*?;', re.DOTALL | re.IGNORECASE) @@ -1117,6 +1228,7 @@ class ScriptCleaner(UpgradeBase): """ @brief Remove "drop/create operator class" statements in the sql script """ + def _clean_opclass(self): # remove 'drop operator class' pattern = re.compile(r'DROP\s+OPERATOR\s*CLASS.*?;', re.DOTALL | re.IGNORECASE) @@ -1132,6 +1244,7 @@ class ScriptCleaner(UpgradeBase): """ @brief Rewrite the type """ + def _rewrite_type_in(self, arg): type_mapper = { 'smallint': '(int2|smallint)', @@ -1184,7 +1297,60 @@ class ScriptCleaner(UpgradeBase): self._clean_function() return 
self._sql + +import unittest + + +class TestChangeHandler(unittest.TestCase): + + def setUp(self): + self._dummy_schema = 'madlib' + self._dummy_portid = 1 + self._dummy_con_args = 'x' + # maddir is the directory one level above current file + # dirname gives the directory of current file (madpack) + # join with pardir adds .. (e.g .../madpack/..) + # abspath concatenates by traversing the .. + self.maddir = os.path.abspath( + os.path.join(os.path.dirname(os.path.realpath(__file__)), + os.pardir)) + self._dummy_hawq2 = False + + def tearDown(self): + pass + + def test_invalid_path(self): + with self.assertRaises(RuntimeError): + ChangeHandler(self._dummy_schema, self._dummy_portid, + self._dummy_con_args, self.maddir, + '1.9', self._dummy_hawq2, + upgrade_to=get_rev_num('1.12')) + + def test_valid_path(self): + ch = ChangeHandler(self._dummy_schema, self._dummy_portid, + self._dummy_con_args, self.maddir, + '1.9.1', self._dummy_hawq2, + upgrade_to=get_rev_num('1.12')) + self.assertEqual(ch.newmodule.keys(), + ['knn', 'sssp', 'apsp', 'measures', 'stratified_sample', + 'encode_categorical', 'bfs', 'mlp', 'pagerank', + 'train_test_split', 'wcc']) + self.assertEqual(ch.udt, {'kmeans_result': None, 'kmeans_state': None}) + self.assertEqual(ch.udf['forest_train'], + [{'argument': 'text, text, text, text, text, text, text, ' + 'integer, integer, boolean, integer, integer, ' + 'integer, integer, integer, text, boolean, ' + 'double precision', + 'rettype': 'void'}, + {'argument': 'text, text, text, text, text, text, text, ' + 'integer, integer, boolean, integer, integer, ' + 'integer, integer, integer, text, boolean', + 'rettype': 'void'}, + {'argument': 'text, text, text, text, text, text, text, ' + 'integer, integer, boolean, integer, integer, ' + 'integer, integer, integer, text', + 'rettype': 'void'}]) + + if __name__ == '__main__': - config = yaml.load(open('changelist.yaml')) - for obj in ('new module', 'udt', 'udc', 'udf', 'uda', 'udo', 'udoc'): - print config[obj] + 
unittest.main() http://git-wip-us.apache.org/repos/asf/madlib/blob/cefd15ea/src/madpack/utilities.py ---------------------------------------------------------------------- diff --git a/src/madpack/utilities.py b/src/madpack/utilities.py index e143d64..40a017a 100644 --- a/src/madpack/utilities.py +++ b/src/madpack/utilities.py @@ -22,10 +22,152 @@ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # from itertools import izip_longest +import os import re +import sys +import subprocess import unittest +# Some read-only variables +this = os.path.basename(sys.argv[0]) # name of this script + + +def error_(src_name, msg, stop=False): + """ + Error message wrapper + @param msg error message + @param stop program exit flag + """ + # Print to stdout + print("{0}: ERROR : {1}".format(src_name, msg)) + # stack trace is not printed + if stop: + exit(2) +# ------------------------------------------------------------------------------ + + +def info_(src_name, msg, verbose=True): + """ + Info message wrapper (verbose) + @param msg info message + @param verbose prints only if True (prevents caller from performing a check) + """ + if verbose: + print("{0}: INFO : {1}".format(src_name, msg)) +# ------------------------------------------------------------------------------ + + +def run_query(sql, con_args, show_error=True): + # Define sqlcmd + sqlcmd = 'psql' + delimiter = ' <$madlib_delimiter$> ' + + # Test the DB cmd line utility + std, err = subprocess.Popen(['which', sqlcmd], stdout=subprocess.PIPE, + stderr=subprocess.PIPE).communicate() + if not std: + error_(this, "Command not found: %s" % sqlcmd, True) + + # Run the query + runcmd = [sqlcmd, + '-h', con_args['host'].split(':')[0], + '-p', con_args['host'].split(':')[1], + '-d', con_args['database'], + '-U', con_args['user'], + '-F', delimiter, + '--no-password', + '--no-psqlrc', + '--no-align', + '-c', sql] + runenv = os.environ + if 'password' in con_args: + runenv["PGPASSWORD"] = con_args['password'] + 
runenv["PGOPTIONS"] = '-c search_path=public -c client_min_messages=error' + std, err = subprocess.Popen(runcmd, env=runenv, stdout=subprocess.PIPE, + stderr=subprocess.PIPE).communicate() + + if err: + if show_error: + error_(this, "SQL command failed: \nSQL: %s \n%s" % (sql, err), False) + if 'password' in err: + raise EnvironmentError + else: + raise Exception + + # Convert the delimited output into a dictionary + results = [] # list of rows + i = 0 + for line in std.splitlines(): + if i == 0: + cols = [name for name in line.split(delimiter)] + else: + row = {} # dict of col_name:col_value pairs + c = 0 + for val in line.split(delimiter): + row[cols[c]] = val + c += 1 + results.insert(i, row) + i += 1 + # Drop the last line: "(X rows)" + try: + results.pop() + except Exception: + pass + + return results + # ------------------------------------------------------------------------------ + + + def get_madlib_dbrev(con_args, schema): + """ + Read MADlib version from database + @param con_args database conection object + @param schema MADlib schema name + """ + try: + n_madlib_versions = int(run_query( + """ + SELECT count(*) AS cnt FROM pg_tables + WHERE schemaname='{0}' AND tablename='migrationhistory' + """.format(schema), + con_args, + True)[0]['cnt']) + if n_madlib_versions > 0: + madlib_version = run_query( + """ + SELECT version + FROM {0}.migrationhistory + ORDER BY applied DESC LIMIT 1 + """.format(schema), + con_args, + True) + if madlib_version: + return madlib_version[0]['version'] + except Exception: + error_(this, "Failed reading MADlib db version", True) + return None + # ------------------------------------------------------------------------------ + + + def get_dbver(con_args, portid): + """ Read version number from database (of form X.Y) """ + try: + versionStr = run_query("SELECT pg_catalog.version()", con_args, True)[0]['version'] + if portid == 'postgres': + match = re.search("PostgreSQL[a-zA-Z\s]*(\d+\.\d+)", versionStr) + elif portid == 'greenplum': + # 
for Greenplum the 3rd digit is necessary to differentiate + # 4.3.5+ from versions < 4.3.5 + match = re.search("Greenplum[a-zA-Z\s]*(\d+\.\d+\.\d+)", versionStr) + elif portid == 'hawq': + match = re.search("HAWQ[a-zA-Z\s]*(\d+\.\d+)", versionStr) + return None if match is None else match.group(1) + except Exception: + error_(this, "Failed reading database version", True) +# ------------------------------------------------------------------------------ + + def is_rev_gte(left, right): """ Return if left >= right
