This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch fix_glm_pmml
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit e8a7715f7425de49eb3793790dad4237d3f82664
Author: Nikhil Kak <[email protected]>
AuthorDate: Mon Feb 5 16:39:40 2024 -0800

    Add test for pmml comparison
---
 .../postgres/modules/pmml/test/pmml.setup.sql_in   | 56 ++++++++++++++++++++++
 .../postgres/modules/pmml/test/pmml_dt.sql_in      | 20 ++++++--
 2 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in b/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in
new file mode 100644
index 00000000..8547aa0b
--- /dev/null
+++ b/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in
@@ -0,0 +1,56 @@
+/* ----------------------------------------------------------------------- *//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ----------------------------------------------------------------------- */
+
+
+-- Tests the PMML output: for every row of test_table, compares the madlib
+-- prediction stored in predict_output with pypmml's prediction from the PMML
+-- string of pmml(train_output); errors on mismatch. pypmml needs java 8-16.
+CREATE OR REPLACE FUNCTION test_pmml_output(train_output TEXT, predict_output TEXT, test_table TEXT) returns boolean as $$
+    import math
+    from pypmml import Model
+
+    # Response/id column names may be stored quoted; strip the quotes.
+    summary = plpy.execute("SELECT dependent_varname, id_col_name FROM {}_summary".format(train_output))[0]
+    output_col = summary["dependent_varname"].replace('"', '')
+    id_col = summary["id_col_name"].replace('"', '')
+    # madlib predictions (column estimated_<response>), keyed by row id.
+    madlib_col = "estimated_{}".format(output_col)
+    madlib_predict_output = dict((r[id_col], r[madlib_col]) for r in
+        plpy.execute("SELECT * FROM {}".format(predict_output)))
+
+    # Unqualified pmml(): found via search_path whatever the MADlib schema is.
+    pmml_query = "SELECT pmml({}) AS pmml".format(plpy.quote_literal(train_output))
+    pypmml_model = Model.fromString(plpy.execute(pmml_query)[0]["pmml"])
+    pypmml_col = "predicted_{}_pmml_prediction".format(output_col)
+
+    for d in plpy.execute("SELECT * FROM {}".format(test_table)):
+        # .get(): a row madlib did not score yields None instead of a KeyError.
+        madlib_result = madlib_predict_output.get(d[id_col])
+        pypmml_result = pypmml_model.predict(d)[pypmml_col]
+        # Float (regression) outputs: allow rounding noise, not exact equality.
+        if isinstance(madlib_result, float) and isinstance(pypmml_result, float):
+            same = math.isclose(madlib_result, pypmml_result, rel_tol=1e-9)
+        else:
+            same = madlib_result == pypmml_result
+        if not same:  # plpy.error raises, aborting the query
+            plpy.error("pmml comparison failed. input row: {}, madlib output: {}, pypmml output: {}".format(d, madlib_result, pypmml_result))
+    return True
+$$ language plpython3u;
\ No newline at end of file
diff --git a/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in b/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in
index c03d6f34..eb6b6c93 100644
--- a/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in
+++ b/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in
@@ -1,3 +1,11 @@
+-- Load the test_pmml_output() helper defined in pmml.setup.sql_in.
+\i m4_regexp(MADLIB_LIBRARY_PATH,
+             `\(.*\)/lib',
+              `\1/../modules/pmml/test/pmml.setup.sql_in'
+)
+
+m4_changequote(`<!'', `!>'')
+
 DROP TABLE IF EXISTS dt_golf;
 CREATE TABLE dt_golf (
     id integer NOT NULL,
@@ -29,7 +37,7 @@ DROP TABLE IF EXISTS train_output, train_output_summary;
 SELECT tree_train('dt_golf'::text,         -- source table
                          'train_output'::text,    -- output model table
                          'id'::text,              -- id column
-                         'temperature::double precision'::text,           -- 
response
+                         'temperature'::text,           -- response
                          'humidity, windy'::text,   -- features
                          NULL::text,        -- exclude columns
                          'gini'::text,      -- split criterion
@@ -46,7 +54,8 @@ SELECT _print_decision_tree(tree) from train_output;
 -- TODO: Enable these lines after the DT tree_display bug is fixed
 -- SELECT tree_display('train_output', False);
 
-SELECT pmml('train_output');
+SELECT tree_predict('train_output', 'dt_golf', 'predict_output_temperature');
+SELECT test_pmml_output('train_output', 'predict_output_temperature', 'dt_golf');
 -------------------------------------------------------------------------
 
 -- classification, grouping
@@ -70,6 +79,11 @@ SELECT tree_train('dt_golf'::text,         -- source table
 SELECT _print_decision_tree(tree) from train_output;
 -- SELECT tree_display('train_output', False);
 
-SELECT pmml('train_output');
+
+-- Score dt_golf with madlib, then check that pypmml, given the PMML export
+-- of the same model, predicts the same value for every row.
+DROP TABLE IF EXISTS predict_output_outlook;
+SELECT tree_predict('train_output', 'dt_golf', 'predict_output_outlook');
+SELECT test_pmml_output('train_output', 'predict_output_outlook', 'dt_golf');
 -------------------------------------------------------------------------
 

Reply via email to