This is an automated email from the ASF dual-hosted git repository. njayaram pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 56efedfe91f239d436c48c3d35fd327d1f5d999e Author: Domino Valdano <[email protected]> AuthorDate: Fri Apr 26 16:39:35 2019 -0700 DL: Refactor computation of images per segment JIRA: MADLIB-1310 Closes #378 --- .../modules/deep_learning/madlib_keras.py_in | 53 +++++++++++++++------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in index 5bf215b..82d1069 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in @@ -89,25 +89,15 @@ def fit(schema_madlib, source_table, model, dependent_varname, # about making the fit function easier to read and maintain. if is_platform_pg(): set_keras_session(use_gpu) - # Compute total images in dataset - total_images_per_seg = plpy.execute( - """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS total_images_per_seg - FROM {1} - """.format(dependent_varname, source_table)) - seg_ids_train = "[]::integer[]" - gp_segment_id_col = -1 else: - # Compute total images on each segment - total_images_per_seg = plpy.execute( - """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS total_images_per_seg - FROM {1} - GROUP BY gp_segment_id - """.format(dependent_varname, source_table)) - seg_ids_train = [int(each_segment["gp_segment_id"]) - for each_segment in total_images_per_seg] - gp_segment_id_col = 'gp_segment_id' # Disable GPU on master for gpdb os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + # Compute total images on each segment + gp_segment_id_col,\ + seg_ids_train,\ + total_images_per_seg = get_images_per_seg(source_table, dependent_varname) + if validation_table: seg_ids_val, rows_per_seg_val = get_rows_per_seg_from_db(validation_table) @@ -296,6 +286,37 @@ def fit(schema_madlib, source_table, model, dependent_varname, if is_platform_pg(): clear_keras_session() +def get_images_per_seg(source_table, dependent_varname): + """ + Compute total images in each segment, by querying source_table. For + postgres, this is just the total number of images in the db. + :param source_table: + :param dependent_var: + :return: Returns a string and two arrays + 1. The appropriate string to use for querying segment number + ("gp_segment_id" for gpdb or "-1" for postgres). + 1. An array containing all the segment numbers in ascending order + 1. An array containing the total images on each of the segments in the + segment array. + """ + if is_platform_pg(): + total_images_per_seg = plpy.execute( + """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS total_images_per_seg + FROM {1} + """.format(dependent_varname, source_table)) + seg_ids_train = "[]::integer[]" + gp_segment_id_col = -1 + else: + total_images_per_seg = plpy.execute( + """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS total_images_per_seg + FROM {1} + GROUP BY gp_segment_id + """.format(dependent_varname, source_table)) + seg_ids_train = [int(each_segment["gp_segment_id"]) + for each_segment in total_images_per_seg] + gp_segment_id_col = 'gp_segment_id' + return gp_segment_id_col, seg_ids_train, total_images_per_seg + def get_rows_per_seg_from_db(table_name): """ This function queries the given table and returns the total rows per segment.
