Repository: madlib
Updated Branches:
  refs/heads/master 32cce1a16 -> cefd15eac


http://git-wip-us.apache.org/repos/asf/madlib/blob/cefd15ea/src/madpack/upgrade_util.py
----------------------------------------------------------------------
diff --git a/src/madpack/upgrade_util.py b/src/madpack/upgrade_util.py
index c521f2e..53051fe 100644
--- a/src/madpack/upgrade_util.py
+++ b/src/madpack/upgrade_util.py
@@ -1,17 +1,28 @@
+from collections import defaultdict, Iterable
+import glob
+import os
 import re
 import yaml
-from collections import defaultdict
-import os
 
 from utilities import is_rev_gte
 from utilities import get_rev_num
+from utilities import run_query
+from utilities import get_dbver
 
-def run_sql(sql, portid, con_args):
-    """
-    @brief Wrapper function for run_query
-    """
-    from madpack import run_query
-    return run_query(sql, True, con_args)
+if __name__ != "__main__":
+    def run_sql(sql, portid, con_args):
+        """
+        @brief Wrapper function for run_query
+        """
+        return run_query(sql, con_args, True)
+else:
+    # Running this file directly executes the unit tests at the bottom of the
+    # file: stub out every database access so no live connection is needed.
+    def run_sql(sql, portid, con_args):
+        return [{'dummy': 0}]
+
+    def get_dbver(con_args, portid):
+        return None
 
 
 def get_signature_for_compare(schema, proname, rettype, argument):
@@ -28,26 +34,34 @@ class UpgradeBase:
     """
     @brief Base class for handling the upgrade
     """
+
     def __init__(self, schema, portid, con_args):
         self._schema = schema.lower()
         self._portid = portid
         self._con_args = con_args
         self._schema_oid = None
         self._get_schema_oid()
+        self._dbver = get_dbver(self._con_args, self._portid)
 
     """
     @brief Wrapper function for run_sql
     """
+
     def _run_sql(self, sql):
         return run_sql(sql, self._portid, self._con_args)
 
     """
     @brief Get the oids of some objects from the catalog in the current version
     """
+
     def _get_schema_oid(self):
-        self._schema_oid = self._run_sql("""
-            SELECT oid FROM pg_namespace WHERE nspname = '{schema}'
-            """.format(schema=self._schema))[0]['oid']
+        res = self._run_sql("SELECT oid FROM pg_namespace WHERE nspname = '{0}'".
+                            format(self._schema))[0]
+        if 'oid' in res:  # the script-mode stub run_sql returns rows without 'oid'
+            self._schema_oid = res['oid']
+        else:
+            self._schema_oid = None
+        return self._schema_oid
 
     def _get_function_info(self, oid):
         """
@@ -92,12 +106,19 @@ class ChangeHandler(UpgradeBase):
     @brief This class reads changes from the configuration file and handles
     the dropping of objects
     """
-    def __init__(self, schema, portid, con_args, maddir, mad_dbrev, is_hawq2):
+
+    def __init__(self, schema, portid, con_args, maddir, mad_dbrev,
+                 is_hawq2, upgrade_to=None):
         UpgradeBase.__init__(self, schema, portid, con_args)
+
+        # FIXME: maddir includes the '/src' folder. It's supposed to be the
+        # parent of that directory.
         self._maddir = maddir
         self._mad_dbrev = mad_dbrev
         self._is_hawq2 = is_hawq2
         self._newmodule = {}
+        self._curr_rev = self._get_current_version() if not upgrade_to else upgrade_to  # upgrade_to overrides Version.yml
+
         self._udt = {}
         self._udf = {}
         self._uda = {}
@@ -106,7 +127,19 @@ class ChangeHandler(UpgradeBase):
         self._udoc = {}
         self._load()
 
-    def _load_config_param(self, config_iterable):
+    def _get_current_version(self):
+        """ Get current version of MADlib (in the format returned by get_rev_num)
+
+        This currently assumes that version is available in
+        '$MADLIB_HOME/src/config/Version.yml'
+        """
+        version_filepath = os.path.abspath(
+            os.path.join(self._maddir, 'config', 'Version.yml'))
+        with open(version_filepath) as ver_file:
+            version_str = str(yaml.load(ver_file, Loader=yaml.BaseLoader)['version'])  # BaseLoader: safe, keeps "1.10" a string (not float 1.1)
+            return get_rev_num(version_str)
+
+    def _load_config_param(self, config_iterable, output_config_dict=None):
         """
         Replace schema_madlib with the appropriate schema name and
         make all function names lower case to ensure ease of comparison.
@@ -114,20 +147,21 @@ class ChangeHandler(UpgradeBase):
         Args:
             @param config_iterable is an iterable of dictionaries, each with
                         key = object name (eg. function name) and value = details
-                        for the object. The details for the object are assumed to
-                        be in a dictionary with following keys:
+                        for the object. The details for the object are assumed
+                        to be in a dictionary with following keys:
                             rettype: Return type
                             argument: List of arguments
+            @param output_config_dict is an optional mapping to add the results
+                        to; it must behave like a defaultdict(list). A new
+                        defaultdict(list) is created when it is None.
 
         Returns:
             A dictionary that lists all specific objects (functions, aggregates, etc)
             with object name as key and a list as value, where the list
-            contains all the items present in
-
-            another dictionary with objects details
-            as the value.
+            contains all the items present in another dictionary with objects
+            details as the value.
         """
-        _return_obj = defaultdict(list)
+        _return_obj = defaultdict(list) if output_config_dict is None else output_config_dict
         if config_iterable is not None:
             for each_config in config_iterable:
                 for obj_name, obj_details in each_config.iteritems():
@@ -138,38 +169,117 @@ class ChangeHandler(UpgradeBase):
                     _return_obj[obj_name].append(formatted_obj)
         return _return_obj
 
+    @classmethod
+    def _add_to_dict(cls, src_dict, dest_dict):
+        """ Update dictionary with contents of another dictionary
+
+        This function performs the same function as dict.update except it adds
+        to an existing value (instead of replacing it) if that value is a
+        list. Always returns dest_dict, even when src_dict is empty or None.
+        """
+        if src_dict:
+            for k, v in src_dict.items():
+                if k in dest_dict:
+                    if isinstance(dest_dict[k], list) and isinstance(v, list):
+                        dest_dict[k] += v
+                    elif isinstance(dest_dict[k], list):
+                        dest_dict[k].append(v)
+                    else:
+                        dest_dict[k] = v
+                else:
+                    dest_dict[k] = v
+        return dest_dict
+
+    def _update_objects(self, config):
+        """ Merge the objects of one changelist config into the upgrade objects """
+        self._add_to_dict(config['new module'], self._newmodule)
+        self._add_to_dict(config['udt'], self._udt)
+        self._add_to_dict(config['udc'], self._udc)
+        self._add_to_dict(self._load_config_param(config['udf']), self._udf)
+        self._add_to_dict(self._load_config_param(config['uda']), self._uda)
+        # FIXME remove the following special handling for HAWQ after svec is
+        # removed from catalog
+        if self._portid != 'hawq' and not self._is_hawq2:
+            self._add_to_dict(self._load_config_param(config.get('udo')), self._udo)
+            self._add_to_dict(self._load_config_param(config.get('udoc')), self._udoc)
+
+    def _get_relevant_filenames(self, upgrade_from):
+        """ Get all changelist files that together describe the upgrade process
+
+        Args:
+            @param upgrade_from: List. Version to upgrade from - the format is
+                expected to be per the output of get_rev_num
+
+        Details:
+            Changelist files are named in the format changelist_<src>_<dest>.yaml
+
+            When upgrading from 'upgrade_from' to 'self._curr_rev', all
+            intermediate changelist files need to be followed to get all upgrade
+            steps. This function globs for such files and filters in changelists
+            that lie between the desired versions.
+
+            Additional verification: The function also ensures that a valid
+            upgrade path exists. Each version in the changelist files needs to
+            be seen twice (except upgrade_from and upgrade_to) for a valid path.
+            This is verified by performing an xor-like operation by
+            adding/deleting from a list.
+
+        Returns:
+            List of changelist file paths, ordered by their source version.
+
+        Raises:
+            RuntimeError if the changelists do not form a complete upgrade path.
+        """
+        output_filenames = []
+        upgrade_to = self._curr_rev
+
+        verify_list = [upgrade_from, upgrade_to]
+
+        # assuming that changelists are in the same directory as this file
+        glob_filter = os.path.abspath(
+            os.path.join(self._maddir, 'madpack', 'changelist*.yaml'))
+        all_changelists = glob.glob(glob_filter)
+        for each_ch in all_changelists:
+            # split file names to get dest versions
+            # Assumption: changelist format is
+            #   changelist_<src>_<dest>.yaml
+            ch_basename = os.path.splitext(os.path.basename(each_ch))[0]  # remove extension
+            ch_splits = ch_basename.split('_')  # underscore delineates sections
+            if len(ch_splits) >= 3:
+                src_version, dest_version = [get_rev_num(i) for i in ch_splits[1:3]]
+
+                # file is part of upgrade if
+                #     upgrade_to >= dest >= src >= upgrade_from
+                is_part_of_upgrade = (
+                    is_rev_gte(src_version, upgrade_from) and
+                    is_rev_gte(upgrade_to, dest_version))
+                if is_part_of_upgrade:
+                    for ver in (src_version, dest_version):
+                        if ver in verify_list:
+                            verify_list.remove(ver)
+                        else:
+                            verify_list.append(ver)
+                    # glob_filter is absolute, so each_ch already is a full path
+                    output_filenames.append((src_version, each_ch))
+
+        if verify_list:
+            # any version remaining in verify_list implies upgrade path is broken
+            raise RuntimeError("Upgrade from {0} to {1} broken due to missing "
+                               "changelist files ({2}). ".
+                               format(upgrade_from, upgrade_to, verify_list))
+        # glob returns files in arbitrary order: apply them by source version
+        return [path for _, path in sorted(output_filenames)]
+
     def _load(self):
         """
         @brief Load the configuration file
         """
-
         rev = get_rev_num(self._mad_dbrev)
-
-        # _mad_dbrev = 1.9.1
-        if is_rev_gte([1,9,1],rev):
-            filename = os.path.join(self._maddir, 'madpack',
-                                    'changelist_1.9.1_1.12.yaml')
-        # _mad_dbrev = 1.10.0
-        elif is_rev_gte([1,10],rev):
-            filename = os.path.join(self._maddir, 'madpack',
-                                    'changelist_1.10.0_1.12.yaml')
-        # _mad_dbrev = 1.11
-        else:
-            filename = os.path.join(self._maddir, 'madpack',
-                                    'changelist.yaml')
-
-        config = yaml.load(open(filename))
-
-        self._newmodule = config['new module'] if config['new module'] else {}
-        self._udt = config['udt'] if config['udt'] else {}
-        self._udc = config['udc'] if config['udc'] else {}
-        self._udf = self._load_config_param(config['udf'])
-        self._uda = self._load_config_param(config['uda'])
-        # FIXME remove the following  special handling for HAWQ after svec is
-        # removed from catalog
-        if self._portid != 'hawq' and not self._is_hawq2:
-            self._udo = self._load_config_param(config['udo'])
-            self._udoc = self._load_config_param(config['udoc'])
+        upgrade_filenames = self._get_relevant_filenames(rev)
+        for f in upgrade_filenames:
+            with open(f) as handle:
+                config = yaml.safe_load(handle)
+                self._update_objects(config)
 
     @property
     def newmodule(self):
@@ -259,8 +357,9 @@ class ChangeHandler(UpgradeBase):
         for opc, li in self._udoc.items():
             for e in li:
                 changed_opcs.add((opc, e['index']))
-
-        if self._portid == 'postgres':
+        gte_gpdb5 = (self._portid == 'greenplum' and
+                     is_rev_gte(get_rev_num(self._dbver), get_rev_num('5.0')))
+        if (self._portid == 'postgres' or gte_gpdb5):  # Greenplum >= 5 uses the PostgreSQL column name
             method_col = 'opcmethod'
         else:
             method_col = 'opcamid'
@@ -339,8 +438,8 @@ class ChangeHandler(UpgradeBase):
         """
         for op in self._udo:
             for value in self._udo[op]:
-                leftarg=value['leftarg'].replace('schema_madlib', self._schema)
-                rightarg=value['rightarg'].replace('schema_madlib', self._schema)
+                leftarg = value['leftarg'].replace('schema_madlib', self._schema)
+                rightarg = value['rightarg'].replace('schema_madlib', self._schema)
                 self._run_sql("""
                     DROP OPERATOR IF EXISTS {schema}.{op} ({leftarg}, {rightarg})
                     """.format(schema=self._schema, **locals()))
@@ -356,11 +455,13 @@ class ChangeHandler(UpgradeBase):
                     DROP OPERATOR CLASS IF EXISTS {schema}.{op_cls} USING {index}
                     """.format(schema=self._schema, **locals()))
 
+
 class ViewDependency(UpgradeBase):
     """
     @brief This class detects the direct/recursive view dependencies on MADLib
     UDFs/UDAs/UDOs defined in the current version
     """
+
     def __init__(self, schema, portid, con_args):
         UpgradeBase.__init__(self, schema, portid, con_args)
         self._view2proc = None
@@ -452,6 +553,7 @@ class ViewDependency(UpgradeBase):
     """
     @brief  Detect recursive view dependencies (view on view)
     """
+
     def _detect_recursive_view_dependency(self):
         rows = self._run_sql("""
             SELECT
@@ -499,9 +601,9 @@ class ViewDependency(UpgradeBase):
     @brief  Filter out recursive view dependencies which are independent of
     MADLib UDFs/UDAs
     """
+
     def _filter_recursive_view_dependency(self):
         # Get initial list
-        import sys
         checklist = []
         checklist.extend(self._view2proc.keys())
         checklist.extend(self._view2op.keys())
@@ -530,6 +632,7 @@ class ViewDependency(UpgradeBase):
     """
     @brief  Build the dependency graph (depender-to-dependee adjacency list)
     """
+
     def _build_dependency_graph(self, hasProcDependency=False):
         der2dee = self._view2view.copy()
         for view in self._view2proc:
@@ -554,12 +657,14 @@ class ViewDependency(UpgradeBase):
     """
     @brief Check dependencies
     """
+
     def has_dependency(self):
         return (len(self._view2proc) > 0) or (len(self._view2op) > 0)
 
     """
     @brief Get the ordered views for creation
     """
+
     def get_create_order_views(self):
         graph = self._build_dependency_graph()
         ordered_views = []
@@ -581,6 +686,7 @@ class ViewDependency(UpgradeBase):
     """
     @brief Get the ordered views for dropping
     """
+
     def get_drop_order_views(self):
         ordered_views = self.get_create_order_views()
         ordered_views.reverse()
@@ -678,10 +784,8 @@ class ViewDependency(UpgradeBase):
                 SET ROLE {owner};
                 CREATE OR REPLACE VIEW {schema}.{view} AS {definition};
                 RESET ROLE
-                """.format(
-                    schema=schema, view=view,
-                    definition=definition,
-                    owner=owner))
+                """.format(schema=schema, view=view,
+                           definition=definition, owner=owner))
 
     def _node_to_str(self, node):
         if len(node) == 2:
@@ -717,6 +821,7 @@ class TableDependency(UpgradeBase):
     @brief This class detects the table dependencies on MADLib UDTs defined in the
     current version
     """
+
     def __init__(self, schema, portid, con_args):
         UpgradeBase.__init__(self, schema, portid, con_args)
         self._table2type = None
@@ -836,6 +941,7 @@ class ScriptCleaner(UpgradeBase):
     @brief This class removes sql statements from a sql script which should not be
     executed during the upgrade
     """
+
     def __init__(self, schema, portid, con_args, change_handler):
         UpgradeBase.__init__(self, schema, portid, con_args)
         self._ch = change_handler
@@ -853,7 +959,9 @@ class ScriptCleaner(UpgradeBase):
         """
         @brief Get the existing UDOCs in the current version
         """
-        if self._portid == 'postgres':
+        gte_gpdb5 = (self._portid == 'greenplum' and
+                     is_rev_gte(get_rev_num(self._dbver), get_rev_num('5.0')))
+        if (self._portid == 'postgres' or gte_gpdb5):  # Greenplum >= 5 uses the PostgreSQL column name
             method_col = 'opcmethod'
         else:
             method_col = 'opcamid'
@@ -887,8 +995,8 @@ class ScriptCleaner(UpgradeBase):
         self._existing_udo = defaultdict(list)
         for row in rows:
             self._existing_udo[row['oprname']].append(
-                    {'leftarg': row['oprleft'],
-                     'rightarg': row['oprright']})
+                {'leftarg': row['oprleft'],
+                 'rightarg': row['oprright']})
 
     def _get_existing_uda(self):
         """
@@ -928,8 +1036,8 @@ class ScriptCleaner(UpgradeBase):
         for row in rows:
             # Consider about the overloaded aggregates
             self._existing_uda[row['proname']].append(
-                                    {'rettype': row['rettype'],
-                                     'argument': row['argument']})
+                {'rettype': row['rettype'],
+                 'argument': row['argument']})
 
     def _get_unchanged_operator_patterns(self):
         """
@@ -938,7 +1046,7 @@ class ScriptCleaner(UpgradeBase):
 
         @return unchanged = existing - changed
         """
-        self._get_existing_udo() # from the old version
+        self._get_existing_udo()  # from the old version
         operator_patterns = []
         # for all, pass the changed ones, add others to ret
         for each_udo, udo_details in self._existing_udo.items():
@@ -965,7 +1073,7 @@ class ScriptCleaner(UpgradeBase):
 
         @return unchanged = existing - changed
         """
-        self._get_existing_udoc() # from the old version
+        self._get_existing_udoc()  # from the old version
         opclass_patterns = []
         # for all, pass the changed ones, add others to ret
         for each_udoc, udoc_details in self._existing_udoc.items():
@@ -1055,6 +1163,7 @@ class ScriptCleaner(UpgradeBase):
     """
     @breif Remove "drop/create type" statements in the sql script
     """
+
     def _clean_type(self):
         # remove 'drop type'
         pattern = re.compile('DROP(\s+)TYPE(.*?);', re.DOTALL | re.IGNORECASE)
@@ -1076,6 +1185,7 @@ class ScriptCleaner(UpgradeBase):
     """
     @brief Remove "drop/create cast" statements in the sql script
     """
+
     def _clean_cast(self):
         # remove 'drop cast'
         pattern = re.compile('DROP(\s+)CAST(.*?);', re.DOTALL | re.IGNORECASE)
@@ -1102,6 +1212,7 @@ class ScriptCleaner(UpgradeBase):
     """
     @brief Remove "drop/create operator" statements in the sql script
     """
+
     def _clean_operator(self):
         # remove 'drop operator'
         pattern = re.compile('DROP\s+OPERATOR.*?PROCEDURE\s+=.*?;', re.DOTALL | re.IGNORECASE)
@@ -1117,6 +1228,7 @@ class ScriptCleaner(UpgradeBase):
     """
     @brief Remove "drop/create operator class" statements in the sql script
     """
+
     def _clean_opclass(self):
         # remove 'drop operator class'
         pattern = re.compile(r'DROP\s+OPERATOR\s*CLASS.*?;', re.DOTALL | re.IGNORECASE)
@@ -1132,6 +1244,7 @@ class ScriptCleaner(UpgradeBase):
     """
     @brief Rewrite the type
     """
+
     def _rewrite_type_in(self, arg):
         type_mapper = {
             'smallint': '(int2|smallint)',
@@ -1184,7 +1297,61 @@ class ScriptCleaner(UpgradeBase):
         self._clean_function()
         return self._sql
 
+
+import unittest
+
+
+class TestChangeHandler(unittest.TestCase):
+
+    def setUp(self):
+        self._dummy_schema = 'madlib'
+        self._dummy_portid = 1
+        self._dummy_con_args = 'x'
+        # maddir is the directory one level above current file
+        #   dirname gives the directory of current file (madpack)
+        #   join with pardir adds .. (e.g .../madpack/..)
+        #   abspath concatenates by traversing the ..
+        self.maddir = os.path.abspath(
+            os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                         os.pardir))
+        self._dummy_hawq2 = False
+
+    def tearDown(self):
+        pass
+
+    def test_invalid_path(self):
+        with self.assertRaises(RuntimeError):
+            ChangeHandler(self._dummy_schema, self._dummy_portid,
+                          self._dummy_con_args, self.maddir,
+                          '1.9', self._dummy_hawq2,
+                          upgrade_to=get_rev_num('1.12'))
+
+    def test_valid_path(self):
+        ch = ChangeHandler(self._dummy_schema, self._dummy_portid,
+                           self._dummy_con_args, self.maddir,
+                           '1.9.1', self._dummy_hawq2,
+                           upgrade_to=get_rev_num('1.12'))
+        # dict key order is arbitrary, so compare the module names order-insensitively
+        self.assertEqual(sorted(ch.newmodule.keys()),
+                         sorted(['knn', 'sssp', 'apsp', 'measures', 'stratified_sample',
+                                 'encode_categorical', 'bfs', 'mlp', 'pagerank',
+                                 'train_test_split', 'wcc']))
+        self.assertEqual(ch.udt, {'kmeans_result': None, 'kmeans_state': None})
+        self.assertEqual(ch.udf['forest_train'],
+                         [{'argument': 'text, text, text, text, text, text, text, '
+                                       'integer, integer, boolean, integer, integer, '
+                                       'integer, integer, integer, text, boolean, '
+                                       'double precision',
+                           'rettype': 'void'},
+                          {'argument': 'text, text, text, text, text, text, text, '
+                                       'integer, integer, boolean, integer, integer, '
+                                       'integer, integer, integer, text, boolean',
+                           'rettype': 'void'},
+                          {'argument': 'text, text, text, text, text, text, text, '
+                                       'integer, integer, boolean, integer, integer, '
+                                       'integer, integer, integer, text',
+                           'rettype': 'void'}])
+
+
 if __name__ == '__main__':
-    config = yaml.load(open('changelist.yaml'))
-    for obj in ('new module', 'udt', 'udc', 'udf', 'uda', 'udo', 'udoc'):
-        print config[obj]
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/madlib/blob/cefd15ea/src/madpack/utilities.py
----------------------------------------------------------------------
diff --git a/src/madpack/utilities.py b/src/madpack/utilities.py
index e143d64..40a017a 100644
--- a/src/madpack/utilities.py
+++ b/src/madpack/utilities.py
@@ -22,10 +22,167 @@
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
 
 from itertools import izip_longest
+import os
 import re
+import sys
+import subprocess
 import unittest
 
 
+# Some read-only variables
+this = os.path.basename(sys.argv[0])    # name of this script
+
+
+def error_(src_name, msg, stop=False):
+    """
+    Error message wrapper
+        @param src_name name of the reporting program, prefixed to the message
+        @param msg error message
+        @param stop if True, exit the program with status 2
+    """
+    # Print to stdout
+    print("{0}: ERROR : {1}".format(src_name, msg))
+    # stack trace is not printed
+    if stop:
+        sys.exit(2)
+# ------------------------------------------------------------------------------
+
+
+def info_(src_name, msg, verbose=True):
+    """
+    Info message wrapper (verbose)
+        @param src_name name of the reporting program, prefixed to the message
+        @param msg info message
+        @param verbose prints only if True (prevents caller from performing a check)
+    """
+    if verbose:
+        print("{0}: INFO : {1}".format(src_name, msg))
+# ------------------------------------------------------------------------------
+
+
+def run_query(sql, con_args, show_error=True):
+    """
+    Run a SQL query through the psql command line client
+        @param sql query text
+        @param con_args dict with 'host' (as "host:port"), 'database', 'user'
+                        and optionally 'password'
+        @param show_error print the failing SQL and psql's stderr if True
+        @return list of rows, each a dict mapping column name to value (str)
+        @raise EnvironmentError if psql's stderr mentions 'password',
+               Exception for any other psql error output
+    """
+    # Define sqlcmd
+    sqlcmd = 'psql'
+    delimiter = ' <$madlib_delimiter$> '
+
+    # Test the DB cmd line utility
+    std, err = subprocess.Popen(['which', sqlcmd], stdout=subprocess.PIPE,
+                                stderr=subprocess.PIPE).communicate()
+    if not std:
+        error_(this, "Command not found: %s" % sqlcmd, True)
+
+    # Run the query
+    runcmd = [sqlcmd,
+              '-h', con_args['host'].split(':')[0],
+              '-p', con_args['host'].split(':')[1],
+              '-d', con_args['database'],
+              '-U', con_args['user'],
+              '-F', delimiter,
+              '--no-password',
+              '--no-psqlrc',
+              '--no-align',
+              '-c', sql]
+    # Work on a copy so PGPASSWORD/PGOPTIONS do not leak into this process' env
+    runenv = dict(os.environ)
+    if 'password' in con_args:
+        runenv["PGPASSWORD"] = con_args['password']
+    runenv["PGOPTIONS"] = '-c search_path=public -c client_min_messages=error'
+    std, err = subprocess.Popen(runcmd, env=runenv, stdout=subprocess.PIPE,
+                                stderr=subprocess.PIPE).communicate()
+
+    if err:
+        if show_error:
+            error_(this, "SQL command failed: \nSQL: %s \n%s" % (sql, err), False)
+        if 'password' in err:
+            raise EnvironmentError
+        else:
+            raise Exception
+
+    # Convert the delimited output into a dictionary
+    results = []  # list of rows
+    i = 0
+    for line in std.splitlines():
+        if i == 0:
+            cols = [name for name in line.split(delimiter)]
+        else:
+            row = {}  # dict of col_name:col_value pairs
+            c = 0
+            for val in line.split(delimiter):
+                row[cols[c]] = val
+                c += 1
+            results.insert(i, row)
+        i += 1
+    # Drop the last line: "(X rows)"
+    try:
+        results.pop()
+    except Exception:
+        pass
+
+    return results
+# ------------------------------------------------------------------------------
+
+
+def get_madlib_dbrev(con_args, schema):
+    """
+    Read MADlib version from database
+        @param con_args database connection arguments (dict)
+        @param schema MADlib schema name
+    """
+    try:
+        n_madlib_versions = int(run_query(
+            """
+                SELECT count(*) AS cnt FROM pg_tables
+                WHERE schemaname='{0}' AND tablename='migrationhistory'
+            """.format(schema),
+            con_args,
+            True)[0]['cnt'])
+        if n_madlib_versions > 0:
+            madlib_version = run_query(
+                """
+                    SELECT version
+                    FROM {0}.migrationhistory
+                    ORDER BY applied DESC LIMIT 1
+                """.format(schema),
+                con_args,
+                True)
+            if madlib_version:
+                return madlib_version[0]['version']
+    except Exception:
+        error_(this, "Failed reading MADlib db version", True)
+    return None
+# ------------------------------------------------------------------------------
+
+
+def get_dbver(con_args, portid):
+    """ Read version number from database (X.Y; X.Y.Z for Greenplum), or None """
+    try:
+        versionStr = run_query("SELECT pg_catalog.version()", con_args, True)[0]['version']
+        match = None  # unsupported portid: report no version instead of a NameError
+        if portid == 'postgres':
+            match = re.search("PostgreSQL[a-zA-Z\s]*(\d+\.\d+)", versionStr)
+        elif portid == 'greenplum':
+            # for Greenplum the 3rd digit is necessary to differentiate
+            # 4.3.5+ from versions < 4.3.5
+            match = re.search("Greenplum[a-zA-Z\s]*(\d+\.\d+\.\d+)", versionStr)
+        elif portid == 'hawq':
+            match = re.search("HAWQ[a-zA-Z\s]*(\d+\.\d+)", versionStr)
+        return None if match is None else match.group(1)
+    except Exception:
+        error_(this, "Failed reading database version", True)
+# ------------------------------------------------------------------------------
+
+
 def is_rev_gte(left, right):
     """ Return if left >= right
 

Reply via email to