This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 56efedfe91f239d436c48c3d35fd327d1f5d999e
Author: Domino Valdano <[email protected]>
AuthorDate: Fri Apr 26 16:39:35 2019 -0700

    DL: Refactor computation of images per segment
    
    JIRA: MADLIB-1310
    Closes #378
---
 .../modules/deep_learning/madlib_keras.py_in       | 53 +++++++++++++++-------
 1 file changed, 37 insertions(+), 16 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 5bf215b..82d1069 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -89,25 +89,15 @@ def fit(schema_madlib, source_table, model, dependent_varname,
     # about making the fit function easier to read and maintain.
     if is_platform_pg():
         set_keras_session(use_gpu)
-        # Compute total images in dataset
-        total_images_per_seg = plpy.execute(
-            """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS total_images_per_seg
-                FROM {1}
-            """.format(dependent_varname, source_table))
-        seg_ids_train = "[]::integer[]"
-        gp_segment_id_col =  -1
     else:
-        # Compute total images on each segment
-        total_images_per_seg = plpy.execute(
-            """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS 
total_images_per_seg
-                FROM {1}
-                GROUP BY gp_segment_id
-            """.format(dependent_varname, source_table))
-        seg_ids_train = [int(each_segment["gp_segment_id"])
-                   for each_segment in total_images_per_seg]
-        gp_segment_id_col = 'gp_segment_id'
         # Disable GPU on master for gpdb
         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+
+    # Compute total images on each segment
+    gp_segment_id_col,\
+    seg_ids_train,\
+    total_images_per_seg = get_images_per_seg(source_table, dependent_varname)
+
     if validation_table:
         seg_ids_val, rows_per_seg_val = 
get_rows_per_seg_from_db(validation_table)
 
@@ -296,6 +286,37 @@ def fit(schema_madlib, source_table, model, dependent_varname,
     if is_platform_pg():
         clear_keras_session()
 
+def get_images_per_seg(source_table, dependent_varname):
+    """
+    Compute the total number of images on each segment by querying
+    source_table.  For postgres, this is the total number of images in the db.
+    :param source_table: Name of the table to query.
+    :param dependent_varname: Array column; ARRAY_LENGTH(col, 1) is summed.
+    :return: A tuple (gp_segment_id_col, seg_ids_train, total_images_per_seg):
+    1. 'gp_segment_id' for gpdb, or the integer -1 for postgres.
+    2. gpdb: list of segment ids in query-result order (GROUP BY does not
+    sort them); postgres: the literal string "[]::integer[]".
+    3. The plpy.execute() result: one row per non-empty segment on gpdb
+    (columns gp_segment_id, total_images_per_seg); one row on postgres.
+    """
+    if is_platform_pg():
+        total_images_per_seg = plpy.execute(
+            """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS total_images_per_seg
+                FROM {1}
+            """.format(dependent_varname, source_table))
+        seg_ids_train = "[]::integer[]"  # SQL array literal text, not a list
+        gp_segment_id_col =  -1  # postgres: no gp_segment_id column queried
+    else:
+        total_images_per_seg = plpy.execute(
+            """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS 
total_images_per_seg
+                FROM {1}
+                GROUP BY gp_segment_id
+            """.format(dependent_varname, source_table))
+        # Same row order as total_images_per_seg, so the two stay aligned.
+        seg_ids_train = [int(each_segment["gp_segment_id"])
+                   for each_segment in total_images_per_seg]
+        gp_segment_id_col = 'gp_segment_id'
+    return gp_segment_id_col, seg_ids_train, total_images_per_seg
+ 
 def get_rows_per_seg_from_db(table_name):
     """
     This function queries the given table and returns the total rows per 
segment.

Reply via email to