This is an automated email from the ASF dual-hosted git repository.
okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new 5f10bc8 DL: Modify multi-fit warm start to accept non-matching mst&model tables
5f10bc8 is described below
commit 5f10bc8e72e88986cd109745dddec672fdaa1d84
Author: Orhan Kislal <[email protected]>
AuthorDate: Tue Jan 7 19:36:34 2020 -0500
DL: Modify multi-fit warm start to accept non-matching mst&model tables
JIRA: MADLIB-1400 #resolve
Warm start previously required the model table to contain a tuple for
every mst_key in the mst table. This commit relaxes that requirement so
that users can add as well as subtract mst keys throughout their AutoML
progress.
Closes #466
---
.../madlib_keras_fit_multiple_model.py_in | 70 ++++++++++++++--------
.../deep_learning/madlib_keras_validator.py_in | 7 ---
.../test/madlib_keras_transfer_learning.sql_in | 24 +++++---
3 files changed, 60 insertions(+), 41 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 5ce555a..273321e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -162,8 +162,8 @@ class FitMultipleModel():
random.shuffle(self.msts_for_schedule)
self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
self.gp_segment_id_col = '0' if is_platform_pg() else
GP_SEGMENT_ID_COLNAME
- if not self.warm_start:
- self.create_model_output_table()
+
+ self.create_model_output_table()
self.weights_to_update_tbl = unique_string(desp='weights_to_update')
self.fit_multiple_model()
reset_cuda_env(original_cuda_env)
@@ -274,12 +274,26 @@ class FitMultipleModel():
plpy.execute(mst_insert_query)
def create_model_output_table(self):
- output_table_create_query = """
- CREATE TABLE {self.model_output_table}
- ({self.mst_key_col} INTEGER PRIMARY KEY,
- {self.model_weights_col} BYTEA,
- {self.model_arch_col} JSON)
- """.format(self=self)
+ warm_start_msts = []
+ if self.warm_start:
+ plpy.execute(""" DELETE FROM {self.model_output_table}
+ WHERE {self.mst_key_col} NOT IN (
+ SELECT {self.mst_key_col} FROM
{self.model_selection_table})
+ """.format(self=self))
+ warm_start_msts = plpy.execute(
+ """ SELECT array_agg({0}) AS a FROM {1}
+ """.format(self.mst_key_col, self.model_output_table))[0]['a']
+ plpy.execute("DROP TABLE {0}".format(self.model_info_table))
+
+ else:
+ output_table_create_query = """
+ CREATE TABLE {self.model_output_table}
+ ({self.mst_key_col} INTEGER PRIMARY
KEY,
+ {self.model_weights_col} BYTEA,
+ {self.model_arch_col} JSON)
+ """.format(self=self)
+ plpy.execute(output_table_create_query)
+
info_table_create_query = """
CREATE TABLE {self.model_info_table}
({self.mst_key_col} INTEGER PRIMARY KEY,
@@ -300,39 +314,32 @@ class FitMultipleModel():
validation_loss DOUBLE PRECISION[])
""".format(self=self)
- plpy.execute(output_table_create_query)
plpy.execute(info_table_create_query)
for mst in self.msts:
model_arch, model_weights =
get_model_arch_weights(self.model_arch_table,
mst[self.model_id_col])
+
+
+ # If warm start is enabled, weights from transfer learning cannot
be
+ # used, even if a particular model doesn't have warm start weigths.
+ if self.warm_start:
+ model_weights = None
+
serialized_weights = get_initial_weights(self.model_output_table,
model_arch,
model_weights,
- False,
+ mst['mst_key'] in
warm_start_msts,
self.use_gpus,
self.accessible_gpus_for_seg
)
- model = model_from_json(model_arch)
- serialized_state = model_weights if model_weights else \
-
madlib_keras_serializer.serialize_nd_weights(model.get_weights())
-
model_size = sys.getsizeof(serialized_weights) / 1024.0
+
metrics_list = get_metrics_from_compile_param(
mst[self.compile_params_col])
is_metrics_specified = True if metrics_list else False
metrics_type = 'ARRAY{0}'.format(
metrics_list) if is_metrics_specified else 'NULL'
- output_table_insert_query = """
- INSERT INTO {self.model_output_table}(
- {self.mst_key_col},
{self.model_weights_col},
- {self.model_arch_col})
- VALUES ({mst_key}, $1, $2)
- """.format(self=self,
- mst_key=mst[self.mst_key_col])
- output_table_insert_query_prepared = plpy.prepare(
- output_table_insert_query, ["bytea", "json"])
- plpy.execute(output_table_insert_query_prepared, [
- serialized_state, model_arch])
+
info_table_insert_query = """
INSERT INTO {self.model_info_table}({self.mst_key_col},
{self.model_id_col}, {self.compile_params_col},
@@ -352,6 +359,19 @@ class FitMultipleModel():
metrics_type=metrics_type)
plpy.execute(info_table_insert_query)
+ if not mst['mst_key'] in warm_start_msts:
+ output_table_insert_query = """
+ INSERT INTO {self.model_output_table}(
+ {self.mst_key_col},
{self.model_weights_col},
+ {self.model_arch_col})
+ VALUES ({mst_key}, $1, $2)
+ """.format(self=self,
+
mst_key=mst[self.mst_key_col])
+ output_table_insert_query_prepared = plpy.prepare(
+ output_table_insert_query, ["bytea", "json"])
+ plpy.execute(output_table_insert_query_prepared, [
+ serialized_weights, model_arch])
+
def create_model_summary_table(self):
if self.warm_start:
plpy.execute("DROP TABLE {0}".format(self.model_summary_table))
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index ad14087..37a2e25 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -437,13 +437,6 @@ class FitMultipleInputValidator(FitCommonValidator):
accessible_gpus_for_seg,
self.module_name)
- if warm_start:
- mst_count = plpy.execute("SELECT count(*) FROM
{0}".format(model_selection_table))[0]['count']
- warm_count = plpy.execute("SELECT count(*) FROM
{0}".format(output_model_table))[0]['count']
-
- _assert(mst_count <= warm_count,
- "{self.module_name} error: Model table and mst table do not
match".format(self=self))
-
class MstLoaderInputValidator():
def __init__(self,
model_arch_table,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
index 3c970a5..d17ea20 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -224,20 +224,18 @@ SELECT madlib_keras_fit_multiple_model(
FALSE, NULL, 1,
TRUE);
-
SELECT assert(
- abs(first.training_loss_final-second.training_loss_final) < 1e-6,
- 'The loss should not change for mst_key 4 since it has been removed from
mst_table')
-FROM iris_model_first_run AS first, iris_multiple_model_info AS second
-WHERE first.mst_key = second.mst_key AND second.mst_key = 4;
+ 4 NOT IN (SELECT mst_key FROM iris_multiple_model),
+ 'mst_key 4 should not be in the model table since it has been removed from
mst_table');
-INSERT INTO mst_table SELECT 4 AS mst_key, model_id, compile_params,
- 'batch_size=8, epochs=1' FROM mst_table WHERE mst_key = 1;
+SELECT assert(
+ 4 NOT IN (SELECT mst_key FROM iris_multiple_model_info),
+ 'mst_key 4 should not be in the info table since it has been removed from
mst_table');
INSERT INTO mst_table SELECT 5 AS mst_key, model_id, compile_params,
'batch_size=18, epochs=1' FROM mst_table WHERE mst_key = 1;
-SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model(
+SELECT madlib_keras_fit_multiple_model(
'iris_data_packed',
'iris_multiple_model',
'mst_table',
@@ -245,7 +243,15 @@ SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model(
FALSE,
NULL, 1,
TRUE -- warm_start
-);$TRAP$) = 1, 'Warm start with extra mst keys should fail.');
+);
+
+SELECT assert(
+ 5 IN (SELECT mst_key FROM iris_multiple_model),
+ 'mst_key 5 should be in the model table since it has been added to
mst_table');
+
+SELECT assert(
+ 5 IN (SELECT mst_key FROM iris_multiple_model_info),
+ 'mst_key 5 should be in the info table since it has been added to
mst_table');
-- Transfer learning tests