reductionista commented on a change in pull request #467: DL: Improve performance of mini-batch preprocessor
URL: https://github.com/apache/madlib/pull/467#discussion_r363576285
##########
File path:
src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
##########
@@ -235,124 +281,324 @@ class InputDataPreprocessorDL(object):
dep_shape = self._get_dependent_var_shape()
dep_shape = ','.join([str(i) for i in dep_shape])
+ one_hot_dep_var_array_expr = self.get_one_hot_encoded_dep_var_expr()
+
+ # skip normalization step if normalizing_const = 1.0
+ if self.normalizing_const and (self.normalizing_const < 0.999999 or
self.normalizing_const > 1.000001):
+        rescale_independent_var = """{self.schema_madlib}.array_scalar_mult(
+            {self.independent_varname}::{FLOAT32_SQL_TYPE}[],
+            (1/{self.normalizing_const})::{FLOAT32_SQL_TYPE})
+            """.format(FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE, **locals())
+ else:
+ self.normalizing_const = DEFAULT_NORMALIZING_CONST
+        rescale_independent_var = "{self.independent_varname}::{FLOAT32_SQL_TYPE}[]".format(FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE, **locals())
+
+    # It's important that we shuffle all rows before batching for fit(),
+    # but we can skip that for predict()
+ order_by_clause = " ORDER BY RANDOM()" if order_by_random else ""
+
if is_platform_pg():
+ # used later for writing summary table
+ self.distribution_rules = '$__madlib__$all_segments$__madlib__$'
+
+ #
+ # For postgres, we just need 3 simple queries:
+ # 1-hot-encode/normalize + batching + bytea conversion
+ #
+
+ # see note in gpdb code branch (lower down) on
+ # 1-hot-encoding of dependent var
+ one_hot_sql = """
+ CREATE TEMP TABLE {normalized_tbl} AS SELECT
+            (ROW_NUMBER() OVER({order_by_clause}) - 1)::INTEGER as row_id,
+ {rescale_independent_var} AS x_norm,
+ {one_hot_dep_var_array_expr} AS y
+ FROM {self.source_table}
+ """.format(**locals())
+
+ plpy_execute(one_hot_sql)
+
+ self.buffer_size = self._get_buffer_size(1)
+
+ make_buffer_id = 'row_id / {0} AS '.format(self.buffer_size)
+
+ dist_by_buffer_id = ''
+ self.run_batch_rows_query(locals())
+ plpy.execute("DROP TABLE {0}".format(normalized_tbl))
+
+ dist_by_dist_key = ''
+ dist_key_col_comma = ''
+ self.convert_to_bytea(locals())
Review comment:
I ended up just removing both of those helper functions. After the refactor, it should be cleaner now, and `locals()` is no longer passed to anything but `plpy.execute()`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services