Repository: madlib
Updated Branches:
  refs/heads/master 3c443e14f -> 3e519dcce


Minibatch Preprocessor: change default buffer size formula to fit grouping

- This commit changes the formula used to calculate the default buffer
  size. Previously, we used num_rows_processed/num_of_segments to estimate
  the data distribution in each segment. To adapt this to the grouping
  scenario, we now use avg_num_rows_processed/num_of_segments (the average
  number of processed rows per group) when there is more than one group of
  data. The other code changes follow from this change.
- This commit also modifies get_seg_number() to return only the number of
  primary segments. Previously, this function returned the total number of
  segments, including the master segment. It now excludes the master
  segment and always returns at least 1.

Closes #256


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/3e519dcc
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/3e519dcc
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/3e519dcc

Branch: refs/heads/master
Commit: 3e519dcce66d0dd4bfcc3c45f246f476c26e26d7
Parents: 3c443e1
Author: Jingyi Mei <j...@pivotal.io>
Authored: Tue Apr 3 17:50:57 2018 -0700
Committer: Nandish Jayaram <njaya...@apache.org>
Committed: Tue Apr 10 15:40:57 2018 -0700

----------------------------------------------------------------------
 .../utilities/minibatch_preprocessing.py_in     | 55 +++++++++++++-------
 .../test/minibatch_preprocessing.sql_in         |  4 +-
 .../test_minibatch_preprocessing.py_in          | 54 ++++++++++---------
 .../postgres/modules/utilities/utilities.py_in  | 11 ++--
 4 files changed, 72 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/3e519dcc/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in 
b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index 4a1c8ae..401323e 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -89,14 +89,14 @@ class MiniBatchPreProcessor:
                                              self.grouping_cols,
                                              self.output_standardization_table)
 
-        num_rows_processed, num_missing_rows_skipped = self.\
-                                                
_get_skipped_rows_processed_count(
-                                                dep_var_array_str,
-                                                indep_var_array_str)
+        total_num_rows_processed, avg_num_rows_processed, \
+        num_missing_rows_skipped = self._get_skipped_rows_processed_count(
+                                            dep_var_array_str,
+                                            indep_var_array_str)
         calculated_buffer_size = MiniBatchBufferSizeCalculator.\
                                          calculate_default_buffer_size(
                                          self.buffer_size,
-                                         num_rows_processed,
+                                         avg_num_rows_processed,
                                          
standardizer.independent_var_dimension)
         """
         This query does the following:
@@ -175,7 +175,7 @@ class MiniBatchPreProcessor:
             dependent_var_dbtype,
             calculated_buffer_size,
             dep_var_classes_str,
-            num_rows_processed,
+            total_num_rows_processed,
             num_missing_rows_skipped,
             self.grouping_cols
             )
@@ -211,27 +211,42 @@ class MiniBatchPreProcessor:
         # Note: Keep the null checking where clause of this query in sync with
         # the main create output table query.
         query = """
-                SELECT COUNT(*) AS source_table_row_count,
-                sum(CASE WHEN
+            SELECT SUM(source_table_row_count_by_group) AS 
source_table_row_count,
+                   SUM(num_rows_processed_by_group) AS 
total_num_rows_processed,
+                   AVG(num_rows_processed_by_group) AS avg_num_rows_processed
+            FROM (
+                SELECT COUNT(*) AS source_table_row_count_by_group,
+                SUM(CASE WHEN
                 NOT {schema_madlib}.array_contains_null({dep_var_array})
                 AND NOT {schema_madlib}.array_contains_null({indep_var_array})
-                THEN 1 ELSE 0 END) AS num_rows_processed
+                THEN 1 ELSE 0 END) AS num_rows_processed_by_group
                 FROM {source_table}
+                {group_by_clause}) s
         """.format(
         schema_madlib = self.schema_madlib,
         source_table = self.source_table,
         dep_var_array = dep_var_array,
-        indep_var_array = indep_var_array)
+        indep_var_array = indep_var_array,
+        group_by_clause = "GROUP BY {0}".format(self.grouping_cols) \
+                          if self.grouping_cols else '')
         result = plpy.execute(query)
 
-        source_table_row_count = result[0]['source_table_row_count']
-        num_rows_processed = result[0]['num_rows_processed']
-        if not source_table_row_count or not num_rows_processed:
+        ## SUM and AVG both return float, and we have to cast them into int for
+        ## summary table. For avg_num_rows_processed we need to ceil first so
+        ## that the minimum won't be 0
+        source_table_row_count = int(result[0]['source_table_row_count'])
+        total_num_rows_processed = int(result[0]['total_num_rows_processed'])
+        avg_num_rows_processed = int(ceil(result[0]['avg_num_rows_processed']))
+        if not source_table_row_count or not total_num_rows_processed or \
+        not avg_num_rows_processed:
             plpy.error("Error while getting the row count of the source table"
                        "{0}".format(self.source_table))
-        num_missing_rows_skipped = source_table_row_count - num_rows_processed
 
-        return num_rows_processed, num_missing_rows_skipped
+        num_missing_rows_skipped = source_table_row_count - 
total_num_rows_processed
+
+        return total_num_rows_processed, avg_num_rows_processed, \
+               num_missing_rows_skipped
+
 
 class MiniBatchQueryFormatter:
     """
@@ -450,7 +465,7 @@ class MiniBatchSummarizer:
                                     dependent_var_dbtype,
                                     buffer_size,
                                     class_values,
-                                    num_rows_processed,
+                                    total_num_rows_processed,
                                     num_missing_rows_skipped,
                                     grouping_cols):
         # 1. All the string columns are surrounded by "$$" to take care of
@@ -467,7 +482,7 @@ class MiniBatchSummarizer:
             $${dependent_var_dbtype}$$::TEXT AS dependent_vartype,
             {buffer_size} AS buffer_size,
             {class_values} AS class_values,
-            {num_rows_processed} AS num_rows_processed,
+            {total_num_rows_processed} AS num_rows_processed,
             {num_missing_rows_skipped} AS num_missing_rows_skipped,
             {grouping_cols}::TEXT AS grouping_cols
         """.format(output_summary_table = output_summary_table,
@@ -478,7 +493,7 @@ class MiniBatchSummarizer:
                    dependent_var_dbtype = dependent_var_dbtype,
                    buffer_size = buffer_size,
                    class_values = class_values,
-                   num_rows_processed = num_rows_processed,
+                   total_num_rows_processed = total_num_rows_processed,
                    num_missing_rows_skipped = num_missing_rows_skipped,
                    grouping_cols = "$$" + grouping_cols + "$$"
                                     if grouping_cols else "NULL")
@@ -491,14 +506,14 @@ class MiniBatchBufferSizeCalculator:
     """
     @staticmethod
     def calculate_default_buffer_size(buffer_size,
-                                      num_rows_processed,
+                                      avg_num_rows_processed,
                                       independent_var_dimension):
         if buffer_size is not None:
             return buffer_size
         num_of_segments = get_seg_number()
 
         default_buffer_size = min(75000000.0/independent_var_dimension,
-                                    float(num_rows_processed)/num_of_segments)
+                                    
float(avg_num_rows_processed)/num_of_segments)
         """
         1. For float number, we need at least one more buffer for the fraction 
part, e.g.
            if default_buffer_size = 0.25, we need to round it to 1.

http://git-wip-us.apache.org/repos/asf/madlib/blob/3e519dcc/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
----------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in 
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
index 04c7fb5..2f8d802 100644
--- a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
+++ b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
@@ -144,7 +144,7 @@ SELECT assert
         num_rows_processed  = 10 AND
         num_missing_rows_skipped    = 0 AND
         grouping_cols       = 'rings',
-        'Summary Validation failed for grouping col. Expected:' || 
__to_char(summary)
+        'Summary Validation failed for grouping col. Actual:' || 
__to_char(summary)
         ) from (select * from minibatch_preprocessing_out_summary) summary;
 
 -- Test that the standardization table gets created.
@@ -283,5 +283,5 @@ SELECT assert
         num_rows_processed  = 1 AND
         num_missing_rows_skipped    = 0 AND
         grouping_cols       = '"rin!#''gs"',
-        'Summary Validation failed for special chars. Expected:' || 
__to_char(summary)
+        'Summary Validation failed for special chars. Actual:' || 
__to_char(summary)
         ) from (select * from minibatch_preprocessing_out_summary) summary;

http://git-wip-us.apache.org/repos/asf/madlib/blob/3e519dcc/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
----------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
 
b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
index 548a6dc..879d77d 100644
--- 
a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
+++ 
b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
@@ -74,28 +74,30 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
 
     def test_minibatch_preprocessor_executes_query(self):
         preprocessor_obj = 
self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                                         "input",
-                                                         "out",
-                                                         self.default_dep_var,
-                                                         self.default_ind_var,
-                                                         self.grouping_cols,
-                                                         
self.default_buffer_size)
+                                                             "input",
+                                                             "out",
+                                                             
self.default_dep_var,
+                                                             
self.default_ind_var,
+                                                             
self.grouping_cols,
+                                                             
self.default_buffer_size)
         self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
-                                                "num_rows_processed":3}], ""]
+                                                "total_num_rows_processed":3,
+                                                "avg_num_rows_processed": 2}], 
""]
         preprocessor_obj.minibatch_preprocessor()
         self.assertEqual(2, self.plpy_mock_execute.call_count)
         self.assertEqual(self.default_buffer_size, 
preprocessor_obj.buffer_size)
 
     def test_minibatch_preprocessor_null_buffer_size_executes_query(self):
         preprocessor_obj = 
self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                                         "input",
-                                                         "out",
-                                                         self.default_dep_var,
-                                                         self.default_ind_var,
-                                                         self.grouping_cols,
-                                                         None)
+                                                             "input",
+                                                             "out",
+                                                             
self.default_dep_var,
+                                                             
self.default_ind_var,
+                                                             
self.grouping_cols,
+                                                             None)
         self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
-        "num_rows_processed":3}], ""]
+                                                "total_num_rows_processed":3,
+                                                "avg_num_rows_processed": 2}], 
""]
         
self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size = Mock()
         preprocessor_obj.minibatch_preprocessor()
         self.assertEqual(2, self.plpy_mock_execute.call_count)
@@ -103,22 +105,22 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
     def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self):
             with self.assertRaises(plpy.PLPYException):
                 self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                                                     
self.default_source_table,
-                                                                     
self.default_output_table,
-                                                                     "y1,y2",
-                                                                     
self.default_ind_var,
-                                                                     
self.grouping_cols,
-                                                                     
self.default_buffer_size)
+                                                  self.default_source_table,
+                                                  self.default_output_table,
+                                                  "y1,y2",
+                                                  self.default_ind_var,
+                                                  self.grouping_cols,
+                                                  self.default_buffer_size)
 
     def test_minibatch_preprocessor_buffer_size_zero_fails(self):
         with self.assertRaises(plpy.PLPYException):
             self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                                             
self.default_source_table,
-                                                             
self.default_output_table,
-                                                             
self.default_dep_var,
-                                                             
self.default_ind_var,
-                                                             
self.grouping_cols,
-                                                             0)
+                                              self.default_source_table,
+                                              self.default_output_table,
+                                              self.default_dep_var,
+                                              self.default_ind_var,
+                                              self.grouping_cols,
+                                              0)
 
     def test_minibatch_preprocessor_buffer_size_one_passes(self):
         #not sure how to assert that an exception has not been raised

http://git-wip-us.apache.org/repos/asf/madlib/blob/3e519dcc/src/ports/postgres/modules/utilities/utilities.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in 
b/src/ports/postgres/modules/utilities/utilities.py_in
index 320082c..40ca40a 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -36,16 +36,19 @@ def is_platform_hawq():
 
 
 def get_seg_number():
-    """ Find out how many primary segments exist in the distribution
-        Might be useful for partitioning data.
+    """ Find out how many primary segments(not include master segment) exist
+        in the distribution. Might be useful for partitioning data.
     """
     if is_platform_pg():
         return 1
     else:
-        return plpy.execute("""
+        count = plpy.execute("""
             SELECT count(*) from gp_segment_configuration
-            WHERE role = 'p'
+            WHERE role = 'p' and content != -1
             """)[0]['count']
+        ## in case some weird gpdb configuration happens, always returns
+        ## primary segment number >= 1
+        return max(1, count)
 # 
------------------------------------------------------------------------------
 
 

Reply via email to