This is an automated email from the ASF dual-hosted git repository. nkak pushed a commit to branch fix_glm_pmml in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 7986d9c580c9ff5d269a9d4758ad6c7325c3c406 Author: Nikhil Kak <[email protected]> AuthorDate: Wed Feb 7 18:06:28 2024 -0800 Fix glm pmml --- src/ports/postgres/modules/pmml/formula.py_in | 22 +++++++++++++--- src/ports/postgres/modules/pmml/pmml_builder.py_in | 22 ++++++++++++---- .../postgres/modules/pmml/test/pmml.setup.sql_in | 15 +++++------ .../modules/pmml/test/pmml_glm_binomial.sql_in | 30 ++++++++++++++++++++-- 4 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/ports/postgres/modules/pmml/formula.py_in b/src/ports/postgres/modules/pmml/formula.py_in index 4a14e0df..659de4b1 100644 --- a/src/ports/postgres/modules/pmml/formula.py_in +++ b/src/ports/postgres/modules/pmml/formula.py_in @@ -8,19 +8,33 @@ class Formula(object): self.x = self.parse(x_str) def parse(self, x_str): - array_expr = re.compile(r'array[[](["a-z0-9_, .]+)[]]', flags=re.I) + array_expr_with_intercept = re.compile(r'array[[]1, (["a-z0-9_, .]+)[]]', flags=re.I) + array_expr_without_intercept = re.compile(r'array[[](["a-z0-9_, .]+)[]]', flags=re.I) simple_col = re.compile(r'["a-z0-9_]+', flags=re.I) prefix = 'x' - if array_expr.match(x_str) is not None: - x_csv = array_expr.sub(r'\1', x_str) + if array_expr_with_intercept.match(x_str) is not None: + x_csv = array_expr_with_intercept.sub(r'\1', x_str) + plpy.info("x_csv is ", x_csv) ret = [s.strip().replace('"','') for s in x_csv.split(',')] + plpy.info("ret is ", ret) + if len(ret) == self.n_coef - 1: + return ret + else: + pass # fall back to using 'x' + elif array_expr_without_intercept.match(x_str) is not None: + x_csv = array_expr_without_intercept.sub(r'\1', x_str) + plpy.info("x_csv is ", x_csv) + ret = [s.strip().replace('"','') for s in x_csv.split(',')] + plpy.info("ret is ", ret) if len(ret) == self.n_coef: return ret else: pass # fall back to using 'x' elif simple_col.match(x_str) is not None: prefix = x_str.replace('"','') - return ["{0}[{1}]".format(prefix, str(i+1)) for i in range(self.n_coef)] + foo = ["{0}[{1}]".format(prefix, str(i+1)) for i in range(self.n_coef)] + # plpy.error(foo) + return foo def rename(self, spec): if isinstance(spec, str): diff --git a/src/ports/postgres/modules/pmml/pmml_builder.py_in b/src/ports/postgres/modules/pmml/pmml_builder.py_in index 125c616a..82d80677 100644 --- a/src/ports/postgres/modules/pmml/pmml_builder.py_in +++ b/src/ports/postgres/modules/pmml/pmml_builder.py_in @@ -209,6 +209,7 @@ class RegressionPMMLBuilder(PMMLBuilder): def _parse_output(self): self.grouped_coefs = self.output self.coef0 = self.output[0]['coef'] + # plpy.error(self.coef0) self.n_coef = len(self.coef0) self.grouping_keys = [k for k in self.output[0] if k != 'coef'] @@ -376,15 +377,26 @@ class GLMPMMLBuilder(GeneralRegressionPMMLBuilder): self._build_ppmatrix() # pcells - pcell_attrib0 = dict(parameterName='p0', beta='0', df='1') - if self.function == 'classification': - pcell_attrib0['targetCategory'] = True - pcell_forest = [PCell(**pcell_attrib0)] + # pcell_attrib0 = dict(parameterName='p0', beta='0', df='1') + # TODO: do we need to include this ? probably not + # if self.function == 'classification': + # pcell_attrib0['targetCategory'] = True + # pcell_forest = [PCell(**pcell_attrib0)] + index = 0 + if len(coef) == len(self.formula.x): + pcell_attrib0 = dict(parameterName='p0', beta='0', df='1') + pcell_forest = [PCell(**pcell_attrib0)] + if self.function == 'classification': + pcell_attrib0['targetCategory'] = True + index += 1 + else: + pcell_forest = [] for i, e in enumerate(coef): - pcell_attrib = dict(parameterName="p"+str(i+1), beta=e, df='1') + pcell_attrib = dict(parameterName="p"+str(index), beta=e, df='1') if self.function == 'classification': pcell_attrib['targetCategory'] = True pcell_forest.append(PCell(**pcell_attrib)) + index += 1 return GeneralRegressionModel(self.mining_schema, self.parameter_list, diff --git a/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in b/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in index 8547aa0b..5c4b5c28 100644 --- a/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in +++ b/src/ports/postgres/modules/pmml/test/pmml.setup.sql_in @@ -22,15 +22,12 @@ -- This function will test the pmml output by comparing the results of madlib's predict function with pypmml's predict function -- Note that pypmml needs java >=8 and <= 16 -CREATE OR REPLACE FUNCTION test_pmml_output(train_output TEXT, predict_output TEXT, test_table TEXT) returns boolean as $$ +CREATE OR REPLACE FUNCTION test_pmml_output(train_output TEXT, test_table TEXT, predict_output_table TEXT, + madlib_output_col TEXT, id_col TEXT, pypmml_col TEXT) returns boolean as $$ from pypmml import Model - train_output_summary = plpy.execute("select * from {}_summary".format(train_output)) - output_col = train_output_summary[0]["dependent_varname"].replace('"', '') - id_col = train_output_summary[0]["id_col_name"].replace('"', '') - madlib_predict_output_table = plpy.execute("SELECT * from {}".format(predict_output)) + madlib_predict_output_table = plpy.execute("SELECT * from {}".format(predict_output_table)) madlib_predict_output = {} - madlib_output_col = "estimated_{}".format(output_col) for res in madlib_predict_output_table: madlib_predict_output[res[id_col]] = res[madlib_output_col] @@ -47,8 +44,10 @@ CREATE OR REPLACE FUNCTION test_pmml_output(train_output TEXT, predict_output TE madlib_result = madlib_predict_output[d[id_col]] pypmml_result = pypmml_model.predict(d) - pypmml_result = pypmml_result["predicted_{}_pmml_prediction".format(output_col)] - + plpy.info(pypmml_result) + if pypmml_col not in pypmml_result: + plpy.error("Invalid pypmml column {}".format(pypmml_col)) + pypmml_result = pypmml_result[pypmml_col] if pypmml_result != madlib_result: plpy.error("pmml comparison failed. input row: {}, madlib output: {}, pypmml output: {}".format(d, madlib_result, pypmml_result)) return False diff --git a/src/ports/postgres/modules/pmml/test/pmml_glm_binomial.sql_in b/src/ports/postgres/modules/pmml/test/pmml_glm_binomial.sql_in index b874bd9e..695f6cfe 100644 --- a/src/ports/postgres/modules/pmml/test/pmml_glm_binomial.sql_in +++ b/src/ports/postgres/modules/pmml/test/pmml_glm_binomial.sql_in @@ -85,6 +85,11 @@ SELECT glm( 'family=binomial, link=probit', NULL, 'max_iter=1000, tolerance=1e-16' ); +CREATE TABLE glm_predict_probit_out as SELECT glm_predict(coef, ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]::float8[], 'probit') +FROM abalone_probit_out, abalone; + +SELECT test_pmml_output('abalone_probit_out', 'abalone', 'glm_predict_probit_out'); + DROP TABLE IF EXISTS abalone_logit_out, abalone_logit_out_summary; SELECT glm( 'abalone', @@ -94,7 +99,28 @@ SELECT glm( 'family=binomial, link=logit', NULL, 'max_iter=1000, tolerance=1e-16' ); +CREATE TABLE glm_predict_logit_out as SELECT glm_predict(coef, ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]::float8[], 'logit') +FROM abalone_logit_out, abalone; + + +-- SELECT pmml('abalone_probit_out'); +-- SELECT pmml('abalone_logit_out'); + + + +-- TODO +-- with 1 intercept + with 0 intercept + with no intercept +-- X +-- different args like family, link etc +-- X +-- categorical + continuous dependent var +-- X +-- Does grouping work with pmml ? +-- X +-- special chars ? +-- also test the names of the columns, not just the values. Might be a good idea to compare the pmml output with an expected value + +-- SELECT tree_predict('train_output', 'abalone', 'predict_output_outlook'); +-- SELECT test_pmml_output('train_output', 'predict_output_outlook', 'abalone'); -SELECT pmml('abalone_probit_out'); -SELECT pmml('abalone_logit_out');
