orhankislal commented on a change in pull request #467: DL: Improve performance of mini-batch preprocessor
URL: https://github.com/apache/madlib/pull/467#discussion_r363554492
##########
 File path: src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
##########
@@ -235,124 +281,324 @@ class InputDataPreprocessorDL(object):
         dep_shape = self._get_dependent_var_shape()
         dep_shape = ','.join([str(i) for i in dep_shape])
+        one_hot_dep_var_array_expr = self.get_one_hot_encoded_dep_var_expr()
+
+        # skip normalization step if normalizing_const = 1.0
+        if self.normalizing_const and (self.normalizing_const < 0.999999 or
+                                       self.normalizing_const > 1.000001):
+            rescale_independent_var = """{self.schema_madlib}.array_scalar_mult(
+                                            {self.independent_varname}::{FLOAT32_SQL_TYPE}[],
+                                            (1/{self.normalizing_const})::{FLOAT32_SQL_TYPE})
+                                         """.format(FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE, **locals())
+        else:
+            self.normalizing_const = DEFAULT_NORMALIZING_CONST
+            rescale_independent_var = "{self.independent_varname}::{FLOAT32_SQL_TYPE}[]".format(
+                FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE, **locals())
+
+        # It's important that we shuffle all rows before batching for fit(),
+        # but we can skip that for predict()
+        order_by_clause = " ORDER BY RANDOM()" if order_by_random else ""
+
         if is_platform_pg():
+            # used later for writing summary table
+            self.distribution_rules = '$__madlib__$all_segments$__madlib__$'
+
+            #
+            # For postgres, we just need 3 simple queries:
+            #   1-hot-encode/normalize + batching + bytea conversion
+            #
+
+            # see note in gpdb code branch (lower down) on
+            # 1-hot-encoding of dependent var
+            one_hot_sql = """
+                CREATE TEMP TABLE {normalized_tbl} AS SELECT
+                    (ROW_NUMBER() OVER({order_by_clause}) - 1)::INTEGER as row_id,
+                    {rescale_independent_var} AS x_norm,
+                    {one_hot_dep_var_array_expr} AS y
+                FROM {self.source_table}
+            """.format(**locals())
+
+            plpy_execute(one_hot_sql)
+
+            self.buffer_size = self._get_buffer_size(1)
+
+            make_buffer_id = 'row_id / {0} AS '.format(self.buffer_size)
+
+            dist_by_buffer_id = ''
+            self.run_batch_rows_query(locals())
+            plpy.execute("DROP TABLE {0}".format(normalized_tbl))
+
+            dist_by_dist_key = ''
+            dist_key_col_comma = ''
+            self.convert_to_bytea(locals())
Review comment:
This function has been renamed, so we should update this line to call the new name. As is, this shouldn't work on postgres.
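To make the failure mode concrete: after the rename, the postgres code path would still reference the old attribute and raise AttributeError at runtime. The sketch below is illustrative only; Preprocessor and run_bytea_conversion_query are hypothetical stand-ins, and the real replacement name should be taken from this PR's diff, not from this example.

    # Self-contained illustration of calling a renamed method by its old name.
    # Class and method names here are invented stand-ins, not MADlib's.
    class Preprocessor(object):
        def run_bytea_conversion_query(self, kwargs):  # hypothetical new name
            return "bytea conversion ran"

    p = Preprocessor()

    try:
        p.convert_to_bytea({})  # stale call, analogous to the flagged line
    except AttributeError as err:
        print("postgres path breaks:", err)

    print(p.run_bytea_conversion_query({}))  # the call once updated to the new name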