Github user ictmalili commented on a diff in the pull request:

    https://github.com/apache/incubator-hawq/pull/904#discussion_r79295886
  
    --- Diff: tools/bin/hawqregister ---
    @@ -126,182 +127,319 @@ def option_parser_yml(yml_file):
         return 'AO', files, sizes, params['AO_Schema'], 
params['Distribution_Policy'], params['AO_FileLocations'], params['Bucketnum']
     
     
    -def create_table(dburl, tablename, schema_info, fmt, distrbution_policy, 
file_locations, bucket_number):
    -    try:
    -        query = "select count(*) from pg_class where relname = '%s';" % 
tablename.split('.')[-1].lower()
    -        conn = dbconn.connect(dburl, False)
    -        rows = dbconn.execSQL(conn, query)
    -        conn.commit()
    -        conn.close()
    -        for row in rows:
    -            if row[0] != 0:
    -                return False
    -    except DatabaseError, ex:
    -        logger.error('Failed to execute query "%s"' % query)
    -        sys.exit(1)
    -
    -    try:
    +class GpRegisterAccessor(object):
    +    def __init__(self, conn):
    +        self.conn = conn
    +        rows = self.exec_query("""
    +        SELECT oid, datname, dat2tablespace,
    +               pg_encoding_to_char(encoding) encoding
    +        FROM pg_database WHERE datname=current_database()""")
    +        self.dbid = rows[0]['oid']
    +        self.dbname = rows[0]['datname']
    +        self.spcid = rows[0]['dat2tablespace']
    +        self.dbencoding = rows[0]['encoding']
    +        self.dbversion = self.exec_query('select version()')[0]['version']
    +
    +    def exec_query(self, sql):
    +        '''execute query and return dict result'''
    +        return self.conn.query(sql).dictresult()
    +
    +    def get_table_existed(self, tablename):
    +        qry = """select count(*) from pg_class where relname = '%s';""" % 
tablename.split('.')[-1].lower()
    +        return self.exec_query(qry)[0]['count'] == 1
    +
    +    def do_create_table(self, tablename, schema_info, fmt, 
distrbution_policy, file_locations, bucket_number):
    +        if self.get_table_existed(tablename):
    +            return False
             schema = ','.join([k['name'] + ' ' + k['type'] for k in 
schema_info])
             fmt = 'ROW' if fmt == 'AO' else fmt
             if fmt == 'ROW':
                 query = ('create table %s(%s) with (appendonly=true, 
orientation=%s, compresstype=%s, compresslevel=%s, checksum=%s, bucketnum=%s) 
%s;'
    -                    % (tablename, schema, fmt, 
file_locations['CompressionType'], file_locations['CompressionLevel'], 
file_locations['Checksum'], bucket_number, distrbution_policy))
    +                     % (tablename, schema, fmt, 
file_locations['CompressionType'], file_locations['CompressionLevel'], 
file_locations['Checksum'], bucket_number, distrbution_policy))
             else: # Parquet
                 query = ('create table %s(%s) with (appendonly=true, 
orientation=%s, compresstype=%s, compresslevel=%s, pagesize=%s, 
rowgroupsize=%s, bucketnum=%s) %s;'
    -                    % (tablename, schema, fmt, 
file_locations['CompressionType'], file_locations['CompressionLevel'], 
file_locations['PageSize'], file_locations['RowGroupSize'], bucket_number, 
distrbution_policy))
    -        conn = dbconn.connect(dburl, False)
    -        rows = dbconn.execSQL(conn, query)
    -        conn.commit()
    -        conn.close()
    +                     % (tablename, schema, fmt, 
file_locations['CompressionType'], file_locations['CompressionLevel'], 
file_locations['PageSize'], file_locations['RowGroupSize'], bucket_number, 
distrbution_policy))
    +        self.conn.query(query)
             return True
    -    except DatabaseError, ex:
    -        print DatabaseError, ex
    -        logger.error('Failed to execute query "%s"' % query)
    -        sys.exit(1)
     
    +    def check_hash_type(self, tablename):
    +        qry = """select attrnums from gp_distribution_policy, pg_class 
where pg_class.relname = '%s' and pg_class.oid = 
gp_distribution_policy.localoid;""" % tablename
    +        rows = self.exec_query(qry)
    +        if len(rows) == 0:
    +            logger.error('Table %s not found in table 
gp_distribution_policy.' % tablename)
    +            sys.exit(1)
    +        if rows[0]['attrnums']:
    +            logger.error('Cannot register file(s) to a table which is hash 
distribuetd.')
    +            sys.exit(1)
     
    -def get_seg_name(dburl, tablename, database, fmt):
    -    try:
    -        relname = ''
    +    # pg_paqseg_#
    +    def get_seg_name(self, tablename, database, fmt):
             tablename = tablename.split('.')[-1]
             query = ("select pg_class2.relname from pg_class as pg_class1, 
pg_appendonly, pg_class as pg_class2 "
                      "where pg_class1.relname ='%s' and pg_class1.oid = 
pg_appendonly.relid and pg_appendonly.segrelid = pg_class2.oid;") % tablename
    -        conn = dbconn.connect(dburl, True)
    -        rows = dbconn.execSQL(conn, query)
    -        conn.commit()
    -        if not rows.rowcount:
    +        rows = self.exec_query(query)
    +        if len(rows) == 0:
                 logger.error('table "%s" not found in db "%s"' % (tablename, 
database))
                 sys.exit(1)
    -        for row in rows:
    -            relname = row[0]
    -        conn.close()
    -    except DatabaseError, ex:
    -        logger.error('Failed to run query "%s" with dbname "%s"' % (query, 
database))
    -        sys.exit(1)
    -    if fmt == 'Parquet':
    -        if relname.find("paq") == -1:
    -            logger.error("table '%s' is not parquet format" % tablename)
    -            sys.exit(1)
    +        relname = rows[0]['relname']
    +        if fmt == 'Parquet':
    +            if relname.find('paq') == -1:
    +                logger.error("table '%s' is not parquet format" % 
tablename)
    +                sys.exit(1)
    +        return relname
    +
    +    def get_distribution_policy_info(self, tablename):
    +        query = "select oid from pg_class where relname = '%s';" % 
tablename.split('.')[-1].lower()
    +        rows = self.exec_query(query)
    +        oid = rows[0]['oid']
    +        query = "select * from gp_distribution_policy where localoid = 
'%s';" % oid
    +        rows = self.exec_query(query)
    +        return rows[0]['attrnums']
    +
    +    def get_bucket_number(self, tablename):
    +        query = "select oid from pg_class where relname = '%s';" % 
tablename.split('.')[-1].lower()
    +        rows = self.exec_query(query)
    +        oid = rows[0]['oid']
    +        query = "select * from gp_distribution_policy where localoid = 
'%s';" % oid
    +        rows = self.exec_query(query)
    +        return rows[0]['bucketnum']
    +
    +    def get_metadata_from_database(self, tablename, seg_name):
    +        query = 'select segno from pg_aoseg.%s;' % seg_name
    +        firstsegno = len(self.exec_query(query)) + 1
    +        # get the full path of correspoding file for target table
    +        query = ("select location, 
gp_persistent_tablespace_node.tablespace_oid, database_oid, relfilenode from 
pg_class, gp_persistent_relation_node, "
    +                 "gp_persistent_tablespace_node, 
gp_persistent_filespace_node where relname = '%s' and pg_class.relfilenode = "
    +                 "gp_persistent_relation_node.relfilenode_oid and 
gp_persistent_relation_node.tablespace_oid = 
gp_persistent_tablespace_node.tablespace_oid "
    +                 "and gp_persistent_filespace_node.filespace_oid = 
gp_persistent_filespace_node.filespace_oid;") % tablename.split('.')[-1]
    +        D = self.exec_query(query)[0]
    +        tabledir = '/'.join([D['location'].strip(), 
str(D['tablespace_oid']), str(D['database_oid']), str(D['relfilenode']), ''])
    +        return firstsegno, tabledir
    +
    +    def update_catalog(self, query):
    +        self.conn.query(query)
    +
    +
    +class HawqRegister(object):
    +    def __init__(self, options, table, utility_conn, conn):
    +        self.yml = options.yml_config
    +        self.filepath = options.filepath
    +        self.database = options.database
    +        self.tablename = table
    +        self.filesize = options.filesize
    +        self.accessor = GpRegisterAccessor(conn)
    +        self.utility_accessor = GpRegisterAccessor(utility_conn)
    +        self.mode = self._init_mode(options.force, options.repair)
    +        self._init()
    +
    +    def _init_mode(self, force, repair):
    +        def table_existed():
    +            return self.accessor.get_table_existed(self.tablename)
    +
    +        if self.yml:
    +            if force:
    +                return 'force'
    +            elif repair:
    +                if not table_existed():
    +                    logger.error('--repair mode asserts the table is 
already create.')
    +                    sys.exit(1)
    +                return 'repair'
    +            else:
    +                return 'second_normal'
    +        else:
    +            return 'first'
     
    -    return relname
    +    def _init(self):
    +        def check_hash_type():
    +            self.accessor.check_hash_type(self.tablename)
     
    +        # check conflicting distributed policy
    +        def check_distribution_policy():
    +            if self.distribution_policy.startswith('DISTRIBUTED BY'):
    +                if len(self.files) % self.bucket_number != 0:
    +                    logger.error('Files to be registered must be multiple 
times to the bucket number of hash table.')
    +                    sys.exit(1)
     
    -def check_hash_type(dburl, tablename):
    -    '''Check whether target table is hash distributed, in that case simple 
insertion does not work'''
    -    try:
    -        query = "select attrnums from gp_distribution_policy, pg_class 
where pg_class.relname = '%s' and pg_class.oid = 
gp_distribution_policy.localoid;" % tablename
    -        conn = dbconn.connect(dburl, False)
    -        rows = dbconn.execSQL(conn, query)
    -        conn.commit()
    -        if not rows.rowcount:
    -            logger.error('Table %s not found in table 
gp_distribution_policy.' % tablename)
    -            sys.exit(1)
    -        for row in rows:
    -            if row[0]:
    -                logger.error('Cannot register file(s) to a table which is 
hash distribuetd.')
    +        def create_table():
    +            return self.accessor.do_create_table(self.tablename, 
self.schema, self.file_format, self.distribution_policy, self.file_locations, 
self.bucket_number)
    +
    +        def get_seg_name():
    +            return self.utility_accessor.get_seg_name(self.tablename, 
self.database, self.file_format)
    +
    +        def get_metadata():
    +            return 
self.accessor.get_metadata_from_database(self.tablename, self.seg_name)
    +
    +        def get_distribution_policy():
    +            return 
self.accessor.get_distribution_policy_info(self.tablename)
    +
    +        def check_policy_consistency():
    +            policy = get_distribution_policy() # "" or "{1,3}"
    +            if policy is None:
    +                if ' 
'.join(self.distribution_policy.strip().split()).lower() == 'distributed 
randomly':
    +                    return
    +                else:
    +                    logger.error('Distribution policy of %s is not 
consistent with previous policy.' % self.tablename)
    +                    sys.exit(1)
    +            tmp_dict = {}
    +            for i, d in enumerate(self.schema):
    +                tmp_dict[d['name']] = i + 1
    +            # 'DISTRIBUETD BY (1,3)' -> {1,3}
    +            cols = 
self.distribution_policy.strip().split()[-1].strip('(').strip(')').split(',')
    +            original_policy = ','.join([str(tmp_dict[col]) for col in 
cols])
    +            if policy.strip('{').strip('}') != original_policy:
    +                logger.error('Distribution policy of %s is not consistent 
with previous policy.' % self.tablename)
                     sys.exit(1)
    -        conn.close()
    -    except DatabaseError, ex:
    -        logger.error('Failed to execute query "%s"' % query)
    -        sys.exit(1)
     
    +        def check_bucket_number():
    +            def get_bucket_number():
    +                return self.accessor.get_bucket_number(self.tablename)
     
    -def get_metadata_from_database(dburl, tablename, seg_name):
    -    '''Get the metadata to be inserted from hdfs'''
    -    try:
    -        query = 'select segno from pg_aoseg.%s;' % seg_name
    -        conn = dbconn.connect(dburl, False)
    -        rows = dbconn.execSQL(conn, query)
    -        conn.commit()
    -        conn.close()
    -    except DatabaseError, ex:
    -        logger.error('Failed to execute query "%s"' % query)
    -        sys.exit(1)
    +            if self.bucket_number != get_bucket_number():
    +                logger.error('Bucket number of %s is not consistent with 
previous bucket number.' % self.tablename)
    +                sys.exit(1)
     
    -    firstsegno = rows.rowcount + 1
    +        if self.yml:
    +            self.file_format, self.files, self.sizes, self.schema, 
self.distribution_policy, self.file_locations, self.bucket_number = 
option_parser_yml(self.yml)
    +            self.filepath = self.files[0][:self.files[0].rfind('/')] if 
self.files else ''
    +            check_distribution_policy()
    +            if self.mode != 'force' and self.mode != 'repair':
    +                if not create_table():
    +                    self.mode = 'second_exist'
    +        else:
    +            self.file_format = 'Parquet'
    +            check_hash_type() # Usage1 only support randomly distributed 
table
    +        if not self.filepath:
    +            sys.exit(0)
     
    -    try:
    -        # get the full path of correspoding file for target table
    -        query = ("select location, 
gp_persistent_tablespace_node.tablespace_oid, database_oid, relfilenode from 
pg_class, gp_persistent_relation_node, "
    -                 "gp_persistent_tablespace_node, 
gp_persistent_filespace_node where relname = '%s' and pg_class.relfilenode = "
    -                 "gp_persistent_relation_node.relfilenode_oid and 
gp_persistent_relation_node.tablespace_oid = 
gp_persistent_tablespace_node.tablespace_oid "
    -                 "and gp_persistent_filespace_node.filespace_oid = 
gp_persistent_filespace_node.filespace_oid;") % tablename.split('.')[-1]
    -        conn = dbconn.connect(dburl, False)
    -        rows = dbconn.execSQL(conn, query)
    -        conn.commit()
    -        conn.close()
    -    except DatabaseError, ex:
    -        logger.error('Failed to execute query "%s"' % query)
    -        sys.exit(1)
    -    for row in rows:
    -        tabledir = '/'.join([row[0].strip(), str(row[1]), str(row[2]), 
str(row[3]), ''])
    -    return firstsegno, tabledir
    -
    -
    -def check_files_and_table_in_same_hdfs_cluster(filepath, tabledir):
    -    '''Check whether all the files refered by 'filepath' and the location 
corresponding to the table are in the same hdfs cluster'''
    -    if not filepath:
    -        return
    -    # check whether the files to be registered is in hdfs
    -    filesystem = filepath.split('://')
    -    if filesystem[0] != 'hdfs':
    -        logger.error('Only support to register file(s) in hdfs')
    -        sys.exit(1)
    -    fileroot = filepath.split('/')
    -    tableroot = tabledir.split('/')
    -    # check the root url of them. eg: for 
'hdfs://localhost:8020/temp/tempfile', we check 'hdfs://localohst:8020'
    -    if fileroot[0:3] != tableroot[0:3]:
    -        logger.error("Files to be registered and the table are not in the 
same hdfs cluster.\nFile(s) to be registered: '%s'\nTable path in HDFS: '%s'" % 
(filepath, tabledir))
    -        sys.exit(1)
    +        self.seg_name = get_seg_name()
    +        self.firstsegno, self.tabledir = get_metadata()
     
    +        if self.mode == 'repair':
    +            if self.tabledir.strip('/') != self.filepath.strip('/'):
    +                logger.error('In repair mode, tablename in yml file should 
be the same with input args')
    --- End diff --
    
    Could we print the input args in this error message, so users can see which values do not match?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to