This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch fix_glm_pmml
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 7986d9c580c9ff5d269a9d4758ad6c7325c3c406
Author: Nikhil Kak <[email protected]>
AuthorDate: Wed Feb 7 18:06:28 2024 -0800

    Fix GLM PMML export (intercept handling in formula parsing and PCell numbering)
---
 src/ports/postgres/modules/pmml/formula.py_in      | 22 +++++++++++++---
 src/ports/postgres/modules/pmml/pmml_builder.py_in | 22 ++++++++++++----
 .../postgres/modules/pmml/test/pmml.setup.sql_in   | 15 +++++------
 .../modules/pmml/test/pmml_glm_binomial.sql_in     | 30 ++++++++++++++++++++--
 4 files changed, 70 insertions(+), 19 deletions(-)

diff --git a/src/ports/postgres/modules/pmml/formula.py_in b/src/ports/postgres/modules/pmml/formula.py_in
index 4a14e0df..659de4b1 100644
--- a/src/ports/postgres/modules/pmml/formula.py_in
+++ b/src/ports/postgres/modules/pmml/formula.py_in
@@ -8,19 +8,33 @@ class Formula(object):
         self.x = self.parse(x_str)
 
     def parse(self, x_str):
-        array_expr = re.compile(r'array[[](["a-z0-9_, .]+)[]]', flags=re.I)
+        array_expr_with_intercept = re.compile(r'array[[]1, (["a-z0-9_, 
.]+)[]]', flags=re.I)
+        array_expr_without_intercept = re.compile(r'array[[](["a-z0-9_, 
.]+)[]]', flags=re.I)
         simple_col = re.compile(r'["a-z0-9_]+', flags=re.I)
         prefix = 'x'
-        if array_expr.match(x_str) is not None:
-            x_csv = array_expr.sub(r'\1', x_str)
+        if array_expr_with_intercept.match(x_str) is not None:
+            x_csv = array_expr_with_intercept.sub(r'\1',  x_str)
+            plpy.info("x_csv is ", x_csv)
             ret = [s.strip().replace('"','') for s in x_csv.split(',')]
+            plpy.info("ret is ", ret)
+            if len(ret) == self.n_coef - 1:
+                return ret
+            else:
+                pass # fall back to using 'x'
+        elif array_expr_without_intercept.match(x_str) is not None:
+            x_csv = array_expr_without_intercept.sub(r'\1',  x_str)
+            plpy.info("x_csv is ", x_csv)
+            ret = [s.strip().replace('"','') for s in x_csv.split(',')]
+            plpy.info("ret is ", ret)
             if len(ret) == self.n_coef:
                 return ret
             else:
                 pass # fall back to using 'x'
         elif simple_col.match(x_str) is not None:
             prefix = x_str.replace('"','')
-        return ["{0}[{1}]".format(prefix, str(i+1)) for i in 
range(self.n_coef)]
+        foo = ["{0}[{1}]".format(prefix, str(i+1)) for i in range(self.n_coef)]
+        # plpy.error(foo)
+        return foo
 
     def rename(self, spec):
         if isinstance(spec, str):
diff --git a/src/ports/postgres/modules/pmml/pmml_builder.py_in b/src/ports/postgres/modules/pmml/pmml_builder.py_in
index 125c616a..82d80677 100644
--- a/src/ports/postgres/modules/pmml/pmml_builder.py_in
+++ b/src/ports/postgres/modules/pmml/pmml_builder.py_in
@@ -209,6 +209,7 @@ class RegressionPMMLBuilder(PMMLBuilder):
     def _parse_output(self):
         self.grouped_coefs = self.output
         self.coef0 = self.output[0]['coef']
+        # plpy.error(self.coef0)
         self.n_coef = len(self.coef0)
         self.grouping_keys = [k for k in self.output[0] if k != 'coef']
 
@@ -376,15 +377,26 @@ class GLMPMMLBuilder(GeneralRegressionPMMLBuilder):
         self._build_ppmatrix()
 
         # pcells
-        pcell_attrib0 = dict(parameterName='p0', beta='0', df='1')
-        if self.function == 'classification':
-            pcell_attrib0['targetCategory'] = True
-        pcell_forest = [PCell(**pcell_attrib0)]
+        # pcell_attrib0 = dict(parameterName='p0', beta='0', df='1')
+        # TODO: do we need to include this ? probably not
+        # if self.function == 'classification':
+        #     pcell_attrib0['targetCategory'] = True
+        # pcell_forest = [PCell(**pcell_attrib0)]
+        index = 0
+        if len(coef) == len(self.formula.x):
+            pcell_attrib0 = dict(parameterName='p0', beta='0', df='1')
+            pcell_forest = [PCell(**pcell_attrib0)]
+            if self.function == 'classification':
+                pcell_attrib0['targetCategory'] = True
+            index += 1
+        else:
+            pcell_forest = []
         for i, e in enumerate(coef):
-            pcell_attrib = dict(parameterName="p"+str(i+1), beta=e, df='1')
+            pcell_attrib = dict(parameterName="p"+str(index), beta=e, df='1')
             if self.function == 'classification':
                 pcell_attrib['targetCategory'] = True
             pcell_forest.append(PCell(**pcell_attrib))
+            index += 1
 
         return GeneralRegressionModel(self.mining_schema,
                                       self.parameter_list,
diff --git a/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in b/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in
index 8547aa0b..5c4b5c28 100644
--- a/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in
+++ b/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in
@@ -22,15 +22,12 @@
 
 -- This function will test the pmml output by comparing the results of 
madlib's predict function with pypmml's predict function
 -- Note that pypmml needs java >=8 and <= 16
-CREATE OR REPLACE FUNCTION test_pmml_output(train_output TEXT, predict_output 
TEXT, test_table TEXT) returns boolean as $$
+CREATE OR REPLACE FUNCTION test_pmml_output(train_output TEXT, test_table 
TEXT, predict_output_table TEXT,
+       madlib_output_col TEXT, id_col TEXT, pypmml_col TEXT) returns boolean 
as $$
     from pypmml import Model
 
-    train_output_summary = plpy.execute("select * from 
{}_summary".format(train_output))
-    output_col = train_output_summary[0]["dependent_varname"].replace('"', '')
-    id_col = train_output_summary[0]["id_col_name"].replace('"', '')
-    madlib_predict_output_table = plpy.execute("SELECT * from 
{}".format(predict_output))
+    madlib_predict_output_table = plpy.execute("SELECT * from 
{}".format(predict_output_table))
     madlib_predict_output = {}
-    madlib_output_col = "estimated_{}".format(output_col)
     for res in madlib_predict_output_table:
         madlib_predict_output[res[id_col]] = res[madlib_output_col]
 
@@ -47,8 +44,10 @@ CREATE OR REPLACE FUNCTION test_pmml_output(train_output TEXT, predict_output TE
         madlib_result = madlib_predict_output[d[id_col]]
 
         pypmml_result = pypmml_model.predict(d)
-        pypmml_result = 
pypmml_result["predicted_{}_pmml_prediction".format(output_col)]
-
+        plpy.info(pypmml_result)
+        if pypmml_col not in pypmml_result:
+            plpy.error("Invalid pypmml column {}".format(pypmml_col))
+        pypmml_result = pypmml_result[pypmml_col]
         if pypmml_result != madlib_result:
             plpy.error("pmml comparison failed. input row: {}, madlib output: 
{}, pypmml output: {}".format(d, madlib_result, pypmml_result))
             return False
diff --git a/src/ports/postgres/modules/pmml/test/pmml_glm_binomial.sql_in b/src/ports/postgres/modules/pmml/test/pmml_glm_binomial.sql_in
index b874bd9e..695f6cfe 100644
--- a/src/ports/postgres/modules/pmml/test/pmml_glm_binomial.sql_in
+++ b/src/ports/postgres/modules/pmml/test/pmml_glm_binomial.sql_in
@@ -85,6 +85,11 @@ SELECT glm(
     'family=binomial, link=probit', NULL, 'max_iter=1000, tolerance=1e-16'
 );
 
+CREATE TABLE glm_predict_probit_out as SELECT glm_predict(coef, ARRAY[1, 
length, diameter, height, whole, shucked, viscera, shell]::float8[], 'probit')
+FROM abalone_probit_out, abalone;
+
+SELECT test_pmml_output('abalone_probit_out', 'abalone', 
'glm_predict_probit_out');
+
 DROP TABLE IF EXISTS abalone_logit_out, abalone_logit_out_summary;
 SELECT glm(
     'abalone',
@@ -94,7 +99,28 @@ SELECT glm(
     'family=binomial, link=logit', NULL, 'max_iter=1000, tolerance=1e-16'
 );
 
+CREATE TABLE glm_predict_logit_out as SELECT glm_predict(coef, ARRAY[1, 
length, diameter, height, whole, shucked, viscera, shell]::float8[], 'logit')
+FROM abalone_logit_out, abalone;
+
+
+-- SELECT pmml('abalone_probit_out');
+-- SELECT pmml('abalone_logit_out');
+
+
+
+-- TODO
+-- with 1 intercept + with 0 intercept + with no intercept
+--      X
+-- different args like family, link etc
+--      X
+-- categorical + continuous dependent var
+--      X
+-- Does grouping work with pmml ?
+--      X
+-- special chars ?
+-- also test the names of the columns, not just the values. Might be a good 
idea to compare the pmml output with an expected value
+
+-- SELECT tree_predict('train_output', 'abalone', 'predict_output_outlook');
+-- SELECT test_pmml_output('train_output', 'predict_output_outlook', 
'abalone');
 
-SELECT pmml('abalone_probit_out');
-SELECT pmml('abalone_logit_out');
 

Reply via email to