This is an automated email from the ASF dual-hosted git repository. khannaekta pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 7e4443bcee31a58a4dd4e9aea5b091bbdaa06759 Author: Ekta Khanna <[email protected]> AuthorDate: Wed Mar 11 13:41:32 2020 -0700 DL: Fix fit multiple to create summary tables with class_values including NULL Prior to this commit, because PL/Python converts SQL NULL array elements to Python None, creating the summary table with a class_values array containing NULL would fail: the Python list was formatted directly into the SQL text. This commit fixes that issue by passing class_values to the query as a prepared-statement parameter instead. --- .../madlib_keras_fit_multiple_model.py_in | 12 +++++------- .../test/madlib_keras_model_selection.sql_in | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in index 93a86f9..a32421b 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in @@ -388,6 +388,7 @@ class FitMultipleModel(): plpy.execute("DROP TABLE {0}".format(self.model_summary_table)) src_summary_dict = get_source_summary_table_dict(self.fit_validator_train) class_values = src_summary_dict['class_values'] + class_values_type = src_summary_dict['class_values_type'] dep_vartype = src_summary_dict['dep_vartype'] dependent_varname = \ src_summary_dict['dependent_varname_in_source_table'] @@ -397,11 +398,8 @@ class FitMultipleModel(): self.validation_table = 'NULL' if self.validation_table is None \ else '$MAD${0}$MAD$'.format(self.validation_table) if class_values is None: - class_values_str = 'NULL::{0}'.format(src_summary_dict['class_values_type']) num_classes = 'NULL' else: - class_values_str = 'ARRAY{0}::{1}'.format(class_values, - src_summary_dict['class_values_type']) num_classes = len(class_values) name = 'NULL' if self.name is None else '$MAD${0}$MAD$'.format(self.name) descr = 'NULL' if self.description is None else '$MAD${0}$MAD$'.format(self.description) @@ -410,7 +408,7 @@ class FitMultipleModel(): 
dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME normalizing_const_colname = NORMALIZING_CONST_COLNAME float32_sql_type = FLOAT32_SQL_TYPE - update_query = """ + create_query = plpy.prepare(""" CREATE TABLE {self.model_summary_table} AS SELECT $MAD${self.source_table}$MAD$::TEXT AS source_table, @@ -429,12 +427,12 @@ class FitMultipleModel(): '{self.end_training_time}'::TIMESTAMP AS end_training_time, '{self.version}'::TEXT AS madlib_version, {num_classes}::INTEGER AS num_classes, - {class_values_str} AS {class_values_colname}, + $1 AS {class_values_colname}, $MAD${dep_vartype}$MAD$::TEXT AS {dependent_vartype_colname}, {norm_const}::{float32_sql_type} AS {normalizing_const_colname}, ARRAY{metrics_iters}::INTEGER[] AS metrics_iters - """.format(**locals()) - plpy.execute(update_query) + """.format(**locals()), [class_values_type]) + plpy.execute(create_query, [class_values]) def update_info_table(self, mst, is_train): mst_key = mst[self.mst_key_col] diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in index 26c1a34..ddf2e0f 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in @@ -367,4 +367,25 @@ SELECT assert(cnt = 1, FROM (SELECT count(*) cnt FROM iris_multiple_model_info WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info; + +-- Test when class values have NULL values +UPDATE iris_data_packed_summary SET class_values = ARRAY['Iris-setosa','Iris-versicolor',NULL]; +DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info; +SELECT madlib_keras_fit_multiple_model( + 'iris_data_packed', + 'iris_multiple_model', + 'mst_table_1row', + 1, + 
FALSE, + NULL, + 1, + FALSE +); + +SELECT assert( + num_classes = 3 AND + class_values = '{Iris-setosa,Iris-versicolor,NULL}', + 'Keras Fit Multiple num_clases and class values Validation failed. Actual:' || __to_char(summary)) +FROM (SELECT * FROM iris_multiple_model_summary) summary; + !>)
