reductionista commented on a change in pull request #467: DL: Improve performance of mini-batch preprocessor
URL: https://github.com/apache/madlib/pull/467#discussion_r363001360
 
 

 ##########
 File path: 
src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
 ##########
 @@ -235,124 +281,324 @@ class InputDataPreprocessorDL(object):
         dep_shape = self._get_dependent_var_shape()
         dep_shape = ','.join([str(i) for i in dep_shape])
 
+        one_hot_dep_var_array_expr = self.get_one_hot_encoded_dep_var_expr()
+
+        # skip normalization step if normalizing_const = 1.0
+        if self.normalizing_const and (self.normalizing_const < 0.999999 or self.normalizing_const > 1.000001):
+            rescale_independent_var = """{self.schema_madlib}.array_scalar_mult(
+                                         {self.independent_varname}::{FLOAT32_SQL_TYPE}[],
+                                         (1/{self.normalizing_const})::{FLOAT32_SQL_TYPE})
+                                      """.format(FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE, **locals())
+        else:
+            self.normalizing_const = DEFAULT_NORMALIZING_CONST
+            rescale_independent_var = "{self.independent_varname}::{FLOAT32_SQL_TYPE}[]".format(FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE, **locals())
+
+        # It's important that we shuffle all rows before batching for fit(), but
+        #  we can skip that for predict()
+        order_by_clause = " ORDER BY RANDOM()" if order_by_random else ""
+
         if is_platform_pg():
+            # used later for writing summary table
+            self.distribution_rules = '$__madlib__$all_segments$__madlib__$'
+
+            #
+            # For postgres, we just need 3 simple queries:
+            #   1-hot-encode/normalize + batching + bytea conversion
+            #
+
+            # see note in gpdb code branch (lower down) on
+            # 1-hot-encoding of dependent var
+            one_hot_sql = """
+                CREATE TEMP TABLE {normalized_tbl} AS SELECT
+                    (ROW_NUMBER() OVER({order_by_clause}) - 1)::INTEGER AS row_id,
+                    {rescale_independent_var} AS x_norm,
+                    {one_hot_dep_var_array_expr} AS y
+                FROM {self.source_table}
+            """.format(**locals())
+
+            plpy_execute(one_hot_sql)
+
+            self.buffer_size = self._get_buffer_size(1)
+
+            make_buffer_id = 'row_id / {0} AS '.format(self.buffer_size)
+
+            dist_by_buffer_id = ''
+            self.run_batch_rows_query(locals())
+            plpy.execute("DROP TABLE {0}".format(normalized_tbl))
+
+            dist_by_dist_key = ''
+            dist_key_col_comma = ''
+            self.convert_to_bytea(locals())
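
For context, a minimal standalone sketch (plain Python, not MADlib code) of two ideas in the hunk above; every name in it is illustrative, not from the PR. The (0.999999, 1.000001) window avoids an exact floating-point comparison with 1.0, and integer division of row_id assigns rows to fixed-size mini-batch buffers:

    EPSILON = 1e-6

    def needs_rescaling(normalizing_const):
        # Mirrors the windowed check above: a user-supplied float may
        # differ from 1.0 only by rounding error, so compare against a
        # small tolerance rather than testing == 1.0.
        return normalizing_const is not None and abs(normalizing_const - 1.0) > EPSILON

    def buffer_id(row_id, buffer_size):
        # Mirrors 'row_id / {buffer_size}' in the SQL built above:
        # consecutive row_ids map to consecutive fixed-size buffers.
        return row_id // buffer_size

    assert needs_rescaling(255.0)
    assert not needs_rescaling(1.0000001)
    assert [buffer_id(r, 4) for r in range(6)] == [0, 0, 0, 0, 1, 1]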
 
 Review comment:
   The description of the function includes a comment listing the local variables that need to be set before the function is called. I notice it's not fully up to date, though, so I'll add the missing ones to the list.
   
   The motivation for making these two helper functions was to avoid having a long, complicated SQL query defined in two different places (one in the PostgreSQL code path, another in the Greenplum code path). At one point I think we had it working a different way, where a template string gets partially filled in near the beginning and the rest is filled in later. I think we ran into some problem with that, but I'll look at it again and see if I can get it to work without introducing other problems.
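   
   For illustration, a minimal sketch of that two-stage template idea, assuming plain str.format with doubled braces; all names here are placeholders, not taken from the PR:
   
       QUERY_TEMPLATE = """
           CREATE TEMP TABLE {normalized_tbl} AS SELECT
               {{make_buffer_id}} buffer_id, x_norm, y
           FROM {source_table}
       """
   
       # Stage 1: fill in what both code paths share; the doubled
       # braces survive this call as single braces.
       partial = QUERY_TEMPLATE.format(normalized_tbl="normalized_tbl",
                                       source_table="source_table")
   
       # Stage 2: fill in the platform-specific piece later.
       pg_query = partial.format(make_buffer_id="row_id / 100 AS")
   
   One known pitfall of this pattern, possibly the problem alluded to above: every literal brace in the query body has to be doubled once per remaining stage, which gets fragile as the template grows.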

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
