This is an automated email from the ASF dual-hosted git repository. nkak pushed a commit to branch fix_glm_pmml in repository https://gitbox.apache.org/repos/asf/madlib.git
commit e8a7715f7425de49eb3793790dad4237d3f82664 Author: Nikhil Kak <[email protected]> AuthorDate: Mon Feb 5 16:39:40 2024 -0800 Add test for pmml comparison --- .../postgres/modules/pmml/test/pmml.setup.sql_in | 56 ++++++++++++++++++++++ .../postgres/modules/pmml/test/pmml_dt.sql_in | 20 ++++++-- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in b/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in new file mode 100644 index 00000000..8547aa0b --- /dev/null +++ b/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in @@ -0,0 +1,56 @@ +/* ----------------------------------------------------------------------- *//** + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + *//* ----------------------------------------------------------------------- */ + + +-- This function will test the pmml output by comparing the results of madlib's predict function with pypmml's predict function +-- Note that pypmml needs java >=8 and <= 16 +CREATE OR REPLACE FUNCTION test_pmml_output(train_output TEXT, predict_output TEXT, test_table TEXT) returns boolean as $$ + from pypmml import Model + + train_output_summary = plpy.execute("select * from {}_summary".format(train_output)) + output_col = train_output_summary[0]["dependent_varname"].replace('"', '') + id_col = train_output_summary[0]["id_col_name"].replace('"', '') + madlib_predict_output_table = plpy.execute("SELECT * from {}".format(predict_output)) + madlib_predict_output = {} + madlib_output_col = "estimated_{}".format(output_col) + for res in madlib_predict_output_table: + madlib_predict_output[res[id_col]] = res[madlib_output_col] + + # get madlib pmml output string + pmml_query = "SELECT madlib.pmml('{}')".format(train_output) + madlib_pmml_str = plpy.execute(pmml_query)[0]["pmml"] + + # load pypmml model using madlib pmml string + pypmml_model = Model.fromString(madlib_pmml_str) + + # load data + test_data = plpy.execute("SELECT * from {}".format(test_table)) + for d in test_data: + madlib_result = madlib_predict_output[d[id_col]] + + pypmml_result = pypmml_model.predict(d) + pypmml_result = pypmml_result["predicted_{}_pmml_prediction".format(output_col)] + + if pypmml_result != madlib_result: + plpy.error("pmml comparison failed. input row: {}, madlib output: {}, pypmml output: {}".format(d, madlib_result, pypmml_result)) + return False + return True +$$ language plpython3u; \ No newline at end of file diff --git a/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in b/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in index c03d6f34..eb6b6c93 100644 --- a/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in +++ b/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in @@ -1,3 +1,11 @@ +\i m4_regexp(MADLIB_LIBRARY_PATH, + `\(.*\)/lib', + `\1/../modules/pmml/test/pmml.setup.sql_in' +) + +m4_changequote(`<!'', `!>'') + + DROP TABLE IF EXISTS dt_golf; CREATE TABLE dt_golf ( id integer NOT NULL, @@ -29,7 +37,7 @@ DROP TABLE IF EXISTS train_output, train_output_summary; SELECT tree_train('dt_golf'::text, -- source table 'train_output'::text, -- output model table 'id'::text, -- id column - 'temperature::double precision'::text, -- response + 'temperature'::text, -- response 'humidity, windy'::text, -- features NULL::text, -- exclude columns 'gini'::text, -- split criterion @@ -46,7 +54,8 @@ SELECT _print_decision_tree(tree) from train_output; -- TODO: Enable these lines after the DT tree_display bug is fixed -- SELECT tree_display('train_output', False); -SELECT pmml('train_output'); +SELECT tree_predict('train_output', 'dt_golf', 'predict_output_temperature'); +SELECT test_pmml_output('train_output', 'predict_output_temperature', 'dt_golf'); ------------------------------------------------------------------------- -- classification, grouping @@ -70,6 +79,11 @@ SELECT tree_train('dt_golf'::text, -- source table SELECT _print_decision_tree(tree) from train_output; -- SELECT tree_display('train_output', False); -SELECT pmml('train_output'); + + + + +SELECT tree_predict('train_output', 'dt_golf', 'predict_output_outlook'); +SELECT test_pmml_output('train_output', 'predict_output_outlook', 'dt_golf'); -------------------------------------------------------------------------
