This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit f06f2d8fa4442b0bf549bef8f9a9ac1c070c8e5e
Author: Nandish Jayaram <[email protected]>
AuthorDate: Thu Apr 18 14:26:15 2019 -0700

    DL: Code refactor
    
    JIRA: MADLIB-1315
    Refactor code and address comments for PR #370.
    
    Closes #370
---
 .../deep_learning/madlib_keras_helper.py_in        | 59 +---------------------
 .../deep_learning/madlib_keras_predict.py_in       | 46 +++++++++++++----
 ...ras_helper.py_in => predict_input_params.py_in} | 26 +++-------
 .../test/unit_tests/test_madlib_keras.py_in        | 17 +++++++
 4 files changed, 63 insertions(+), 85 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index d56a0e3..445b5b9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -18,12 +18,8 @@
 # under the License.
 
 import numpy as np
-import plpy
-from keras_model_arch_table import Format
-from utilities.utilities import add_postfix
-from utilities.validate_args import input_tbl_valid
 
-# Prepend 1 to np arrays using expand_dims.
+# Prepend a dimension to np arrays using expand_dims.
 def expand_input_dims(input_data, target_type=None):
     input_data = np.array(input_data)
     input_data = np.expand_dims(input_data, axis=0)
@@ -40,56 +36,3 @@ DEPENDENT_VARTYPE_COLNAME = "dependent_vartype"
 MODEL_ARCH_TABLE_COLNAME = "model_arch_table"
 MODEL_ARCH_ID_COLNAME = "model_arch_id"
 MODEL_DATA_COLNAME = "model_data"
-
-class PredictParamsProcessor:
-    def __init__(self, model_table, module_name):
-        self.module_name = module_name
-        self.model_table = model_table
-        self.model_summary_table = add_postfix(self.model_table, '_summary')
-        input_tbl_valid(self.model_summary_table, self.module_name)
-        self.model_summary_dict = self._get_model_summary_dict()
-        self.model_arch_dict = self._get_model_arch_dict()
-
-    def _get_model_summary_dict(self):
-        return plpy.execute("SELECT * FROM {0}".format(
-            self.model_summary_table))[0]
-
-    def _get_model_arch_dict(self):
-        model_arch_table = self.model_summary_dict[MODEL_ARCH_TABLE_COLNAME]
-        model_arch_id = self.model_summary_dict[MODEL_ARCH_ID_COLNAME]
-        input_tbl_valid(model_arch_table, self.module_name)
-        model_arch_query = """
-            SELECT {0}
-            FROM {1}
-            WHERE {2} = {3}
-        """.format(Format.MODEL_ARCH, model_arch_table, Format.MODEL_ID,
-                   model_arch_id)
-        query_result = plpy.execute(model_arch_query)
-        if not query_result or len(query_result) == 0:
-            plpy.error("{0}: No model arch found in table {1} with id {2}".format(
-                self.module_name, model_arch_table, model_arch_id))
-        return query_result[0]
-
-    def get_class_values(self):
-        return self.model_summary_dict[CLASS_VALUES_COLNAME]
-
-    def get_compile_params(self):
-        return self.model_summary_dict[COMPILE_PARAMS_COLNAME]
-
-    def get_dependent_varname(self):
-        return self.model_summary_dict[DEPENDENT_VARNAME_COLNAME]
-
-    def get_dependent_vartype(self):
-        return self.model_summary_dict[DEPENDENT_VARTYPE_COLNAME]
-
-    def get_model_arch(self):
-        return self.model_arch_dict[Format.MODEL_ARCH]
-
-    def get_model_data(self):
-        return plpy.execute("""
-                SELECT {0} FROM {1}
-            """.format(MODEL_DATA_COLNAME, self.model_table)
-                            )[0][MODEL_DATA_COLNAME]
-
-    def get_normalizing_const(self):
-        return self.model_summary_dict[NORMALIZING_CONST_COLNAME]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 3108be5..95ae2cf 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -28,9 +28,9 @@ from keras.optimizers import *
 import numpy as np
 
 from madlib_keras_helper import expand_input_dims
-from madlib_keras_helper import PredictParamsProcessor
 from madlib_keras_helper import MODEL_DATA_COLNAME
 from madlib_keras_wrapper import compile_and_set_weights
+from predict_input_params import PredictParamsProcessor
 from utilities.model_arch_info import get_input_shape
 from utilities.utilities import add_postfix
 from utilities.utilities import create_cols_from_array_sql_string
@@ -51,6 +51,41 @@ def validate_pred_type(pred_type, class_values):
             " max number of columns that can be created (1600)".format(
                 MODULE_NAME, len(class_values)+1)})
 
+def _strip_trailing_nulls_from_class_values(class_values):
+    """
+        class_values is a list of unique class levels in training data. This
+        could have multiple Nones in it, and this function strips out all the
+        Nones that occur after the first element in the list.
+        Examples:
+            1) input class_values = ['cat', 'dog']
+               output class_values = ['cat', 'dog']
+
+            2) input class_values = [None, 'cat', 'dog']
+               output class_values = [None, 'cat', 'dog']
+
+            3) input class_values = [None, 'cat', 'dog', None, None]
+               output class_values = [None, 'cat', 'dog']
+
+            4) input class_values = ['cat', 'dog', None, None]
+               output class_values = ['cat', 'dog']
+
+            5) input class_values = [None, None]
+               output class_values = [None]
+        @args:
+            @param: class_values, list
+        @returns:
+            updated class_values list
+    """
+    num_of_valid_class_values = 0
+    if class_values is not None:
+        for ele in class_values:
+            if ele is None and num_of_valid_class_values > 0:
+                break
+            num_of_valid_class_values += 1
+        # Pass only the valid class_values for creating columns
+        class_values = class_values[:num_of_valid_class_values]
+    return class_values
+
 def predict(schema_madlib, model_table, test_table, id_col,
             independent_varname, output_table, pred_type, **kwargs):
     # Refactor and add more validation as part of MADLIB-1312.
@@ -80,14 +115,7 @@ def predict(schema_madlib, model_table, test_table, id_col,
         pred_col_name = "prob"
         pred_col_type = 'double precision'
 
-    num_of_valid_class_values = 0
-    if class_values is not None:
-        for ele in class_values:
-            if ele is None and num_of_valid_class_values > 0:
-                break
-            num_of_valid_class_values += 1
-        # Pass only the valid class_values for creating columns
-        class_values = class_values[:num_of_valid_class_values]
+    class_values = _strip_trailing_nulls_from_class_values(class_values)
 
     prediction_select_clause = create_cols_from_array_sql_string(
         class_values, intermediate_col, pred_col_name,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
similarity index 81%
copy from src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
copy to src/ports/postgres/modules/deep_learning/predict_input_params.py_in
index d56a0e3..69ee961 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
@@ -17,29 +17,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import numpy as np
 import plpy
 from keras_model_arch_table import Format
 from utilities.utilities import add_postfix
 from utilities.validate_args import input_tbl_valid
 
-# Prepend 1 to np arrays using expand_dims.
-def expand_input_dims(input_data, target_type=None):
-    input_data = np.array(input_data)
-    input_data = np.expand_dims(input_data, axis=0)
-    if target_type:
-        input_data = input_data.astype(target_type)
-    return input_data
-
-# Name of columns in model summary table.
-CLASS_VALUES_COLNAME = "class_values"
-NORMALIZING_CONST_COLNAME = "normalizing_const"
-COMPILE_PARAMS_COLNAME = "compile_params"
-DEPENDENT_VARNAME_COLNAME = "dependent_varname"
-DEPENDENT_VARTYPE_COLNAME = "dependent_vartype"
-MODEL_ARCH_TABLE_COLNAME = "model_arch_table"
-MODEL_ARCH_ID_COLNAME = "model_arch_id"
-MODEL_DATA_COLNAME = "model_data"
+from madlib_keras_helper import CLASS_VALUES_COLNAME
+from madlib_keras_helper import COMPILE_PARAMS_COLNAME
+from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
+from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
+from madlib_keras_helper import MODEL_ARCH_ID_COLNAME
+from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME
+from madlib_keras_helper import MODEL_DATA_COLNAME
+from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 
 class PredictParamsProcessor:
     def __init__(self, model_table, module_name):
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 533e347..e6b09e4 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -402,6 +402,23 @@ class MadlibKerasPredictTestCase(unittest.TestCase):
         self.subject.validate_pred_type('response', range(1598))
         self.subject.validate_pred_type('response', None)
 
+    def test_strip_trailing_nulls_from_class_values(self):
+        self.assertEqual(['cat', 'dog'],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                ['cat', 'dog']))
+        self.assertEqual([None, 'cat', 'dog'],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                [None, 'cat', 'dog']))
+        self.assertEqual([None, 'cat', 'dog'],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                [None, 'cat', 'dog', None, None]))
+        self.assertEqual(['cat', 'dog'],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                ['cat', 'dog', None, None]))
+        self.assertEqual([None],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                [None, None]))
+
 class MadlibKerasHelperTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')

Reply via email to