reductionista commented on a change in pull request #355: Keras fit interface
URL: https://github.com/apache/madlib/pull/355#discussion_r267108455
 
 

 ##########
 File path: src/ports/postgres/modules/convex/madlib_keras.py_in
 ##########
 @@ -0,0 +1,633 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import os
+import plpy
+import time
+
+from keras import backend as K
+from keras import utils as keras_utils
+from keras.layers import *
+from keras.models import *
+from keras.optimizers import *
+from keras.regularizers import *
+import numpy as np
+
+from madlib_keras_helper import *
+from utilities.model_arch_info import get_input_shape
+from utilities.validate_args import input_tbl_valid
+from utilities.validate_args import output_tbl_valid
+from utilities.utilities import _assert
+from utilities.utilities import add_postfix
+from utilities.utilities import is_var_valid
+from utilities.utilities import madlib_version
+
+def _validate_input_table(source_table, independent_varname,
+                          dependent_varname):
+    _assert(is_var_valid(source_table, independent_varname),
+            "model_keras error: invalid independent_varname "
+            "('{independent_varname}') for source_table "
+            "({source_table})!".format(
+                independent_varname=independent_varname,
+                source_table=source_table))
+
+    _assert(is_var_valid(source_table, dependent_varname),
+            "model_keras error: invalid dependent_varname "
+            "('{dependent_varname}') for source_table "
+            "({source_table})!".format(
+                dependent_varname=dependent_varname,
+                source_table=source_table))
+
+def _validate_input_args(
+    source_table, dependent_varname, independent_varname, model_arch_table,
+    validation_table, output_model_table, num_iterations):
+
+    module_name = 'model_keras'
+    _assert(num_iterations > 0,
+        "model_keras error: Number of iterations cannot be < 1.")
+
+    output_summary_model_table = add_postfix(output_model_table, "_summary")
+    input_tbl_valid(source_table, module_name)
+    # Source table and validation tables must have the same schema
+    _validate_input_table(source_table, independent_varname, dependent_varname)
+    if validation_table and validation_table.strip() != '':
+        input_tbl_valid(validation_table, module_name)
+        _validate_input_table(validation_table, independent_varname,
+                              dependent_varname)
+    # Validate model arch table's schema.
+    input_tbl_valid(model_arch_table, module_name)
+    # Validate output tables
+    output_tbl_valid(output_model_table, module_name)
+    output_tbl_valid(output_summary_model_table, module_name)
+
+def _validate_input_shapes(source_table, independent_varname, input_shape):
+    """
+    Validate that the input shape specified in the model architecture is the
+    same as the shape of the image specified in the independent var of the input
+    table.
+    """
+    # The weird indexing with 'i+2' and 'i' below has two reasons:
+    # 1) The indexing for array_upper() starts from 1, but indexing in the
+    # input_shape list starts from 0.
+    # 2) Input_shape is only the image's dimension, whereas a row of
+    # independent varname in a table contains buffer size as the first
+    # dimension, followed by the image's dimension. So we must ignore
+    # the first dimension from independent varname.
+    array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
+        independent_varname, i+2, i) for i in range(len(input_shape)))
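+    # e.g. for a 3-d input_shape and independent_varname "x" this builds:
+    # "array_upper(x, 2) AS n_0, array_upper(x, 3) AS n_1, array_upper(x, 4) AS n_2"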
+    query = """
+        SELECT {0}
+        FROM {1}
+        LIMIT 1
+    """.format(array_upper_query, source_table)
+    # This query will fail if an image in independent var does not have the
+    # same number of dimensions as the input_shape.
+    result = plpy.execute(query)[0]
+    _assert(len(result) == len(input_shape),
+        "model_keras error: The number of dimensions ({0}) of each image in" \
+        " model architecture and {1} in {2} ({3}) do not match.".format(
+            len(input_shape), independent_varname, source_table, len(result)))
+    for i in range(len(input_shape)):
+        key_name = "n_{0}".format(i)
+        if result[key_name] != input_shape[i]:
+            # Construct the shape in independent varname to display meaningful
+            # error msg.
+            input_shape_from_table = [result["n_{0}".format(j)]
+                                      for j in range(len(input_shape))]
+            plpy.error("model_keras error: Input shape {0} in the model" \
+                " architecture does not match the input shape {1} of column" \
+                " {2} in table {3}.".format(input_shape, 
input_shape_from_table, independent_varname, source_table))
+
+def fit(schema_madlib, source_table, model, dependent_varname,
+        independent_varname, model_arch_table, model_arch_id, compile_params,
+        fit_params, num_iterations, num_classes, use_gpu=True,
+        validation_table=None, name="", description="", **kwargs):
+    _validate_input_args(source_table, dependent_varname, independent_varname,
+                         model_arch_table, validation_table,
+                         model, num_iterations)
+
+    start_training_time = datetime.datetime.now()
+
+    # Disable GPU on master
+    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+
+    use_gpu = bool(use_gpu)
+
+    # Get the serialized master model
+    start_deserialization = time.time()
+    model_arch_query = "SELECT model_arch, model_weights FROM {0} " \
+                       "WHERE id = {1}".format(model_arch_table, model_arch_id)
+    query_result = plpy.execute(model_arch_query)
+    if not query_result or len(query_result) == 0:
+        plpy.error("no model arch found in table {0} with id {1}".format(
+            model_arch_table, model_arch_id))
+    query_result = query_result[0]
+    model_arch = query_result['model_arch']
+    input_shape = get_input_shape(model_arch)
+    _validate_input_shapes(source_table, independent_varname, input_shape)
+    if validation_table:
+        _validate_input_shapes(
+            validation_table, independent_varname, input_shape)
+    model_weights_serialized = query_result['model_weights']
+
+    # Convert model from json and initialize weights
+    master_model = model_from_json(model_arch)
+    model_weights = master_model.get_weights()
+
+    # Get shape of weights in each layer from model arch
+    model_shapes = []
+    for weight_arr in master_model.get_weights():
+        model_shapes.append(weight_arr.shape)
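+    # each entry is a numpy shape tuple, e.g. (784, 256) for the kernel of
+    # a Dense(256) layer fed 784 inputs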
+
+    if model_weights_serialized:
+        # If warm start from previously trained model, set weights
+        model_weights = deserialize_weights_orig(model_weights_serialized,
+                                                 model_shapes)
+        master_model.set_weights(model_weights)
+
+    end_deserialization = time.time()
+    # plpy.info("Model deserialization time: {} sec".format(
+    #     end_deserialization - start_deserialization))
+
+    # Construct validation dataset if provided
+    validation_set_provided = bool(validation_table)
+    validation_aggregate_accuracy = []
+    validation_aggregate_loss = []
+    x_validation = None
+    y_validation = None
+    if validation_set_provided:
+        x_validation, y_validation = get_data_as_np_array(validation_table,
+                                                           dependent_varname,
+                                                           independent_varname,
+                                                           input_shape,
+                                                           num_classes)
+
+    # Compute total buffers on each segment
+    total_buffers_per_seg = plpy.execute(
+        """ SELECT gp_segment_id, count(*) AS total_buffers_per_seg
+            FROM {0}
+            GROUP BY gp_segment_id
+        """.format(source_table))
+    seg_nums = [int(each_buffer["gp_segment_id"])
+                for each_buffer in total_buffers_per_seg]
+    total_buffers_per_seg = [int(each_buffer["total_buffers_per_seg"])
+                             for each_buffer in total_buffers_per_seg]
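+    # e.g. on a two-segment Greenplum cluster this might yield
+    # seg_nums = [0, 1] and total_buffers_per_seg = [12, 11]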
+
+    # Prepare the SQL for running distributed training via UDA
+    compile_params_to_pass = "$madlib$" + compile_params + "$madlib$"
+    fit_params_to_pass = "$madlib$" + fit_params + "$madlib$"
 
 Review comment:
   This allows the user to inject arbitrary SQL commands by including the
string `$madlib$` in the compile or fit params. That isn't as bad as letting
them execute arbitrary Python code, but it is also something that could happen
unintentionally and result in strange errors. We could reduce the risk by
choosing a random string, but see below for a better way of handling this.
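
   To make the failure mode concrete, here is a minimal sketch (my own
illustration, not the fix suggested below) of how the dollar-quoting breaks,
and one possible mitigation using PL/Python's `plpy.quote_literal` escaping
helper (available in PostgreSQL 9.1+):

```python
# Hypothetical user input -- the embedded $madlib$ terminates the
# dollar-quoted literal early, and everything after it is parsed as SQL:
compile_params = "loss='mse'$madlib$); DROP TABLE some_table; --"

# What the current code builds -- vulnerable to the input above:
unsafe_literal = "$madlib$" + compile_params + "$madlib$"

# One safer alternative: quote_literal escapes any quoting characters in
# the user input and wraps it as an ordinary SQL string literal:
safe_literal = plpy.quote_literal(compile_params)
```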

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
