This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 4c97c1c376727f552a80c60217648c37d76fcb4e Author: Matthias Boehm <[email protected]> AuthorDate: Fri Apr 14 20:17:04 2023 +0200 [SYSTEMDS-3149] Additional RSS impurity measure for regression trees This patch adds an additional impurity measure, besides gini and entropy, to the decisionTree and randomForest builtin functions. The new measure is rss (residual sum of squares) for regression in order to properly learn the tree with regard to the final accuracy metrics. --- scripts/builtin/decisionTree.dml | 11 ++++- scripts/builtin/lmPredict.dml | 2 +- scripts/builtin/lmPredictStats.dml | 8 +++- scripts/builtin/randomForest.dml | 2 +- scripts/builtin/randomForestPredict.dml | 2 +- .../part1/BuiltinDecisionTreeRealDataTest.java | 16 +++++-- src/test/resources/datasets/wine/tfspec.json | 2 +- .../functions/builtin/decisionTreeRealData3.dml | 50 ++++++++++++++++++++++ 8 files changed, 82 insertions(+), 11 deletions(-) diff --git a/scripts/builtin/decisionTree.dml b/scripts/builtin/decisionTree.dml index 6a85367526..384f591b73 100644 --- a/scripts/builtin/decisionTree.dml +++ b/scripts/builtin/decisionTree.dml @@ -38,7 +38,7 @@ # candidates at tree nodes: m = ceil(num_features^max_features) # max_values Parameter controlling the number of values per feature used # as split candidates: nb = ceil(num_values^max_values) -# impurity Impurity measure: entropy, gini (default) +# impurity Impurity measure: entropy, gini (default), rss (regression) # seed Fixed seed for randomization of samples and split candidates # verbose Flag indicating verbose debug output # ------------------------------------------------------------------------------ @@ -72,7 +72,9 @@ m_decisionTree = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty if( max_depth > 32 ) stop("decisionTree: invalid max_depth > 32: "+max_depth); if( sum(X<=0) != 0 ) - stop("decisionTree: feature matrix X is not properly 
recoded/binned (values <= 0): "+sum(X<=0)); + if( sum(abs(X-round(X))>1e-14) != 0 ) + stop("decisionTree: feature matrix X is not properly recoded/binned (non-integer): "+sum(abs(X-round(X))>1e-14)); if( sum(y<=0) != 0 ) stop("decisionTree: label vector y is not properly recoded/binned: "+sum(y<=0)); @@ -230,6 +232,11 @@ computeImpurity = function(Matrix[Double] y2, Matrix[Double] I, String impurity) score = 1 - rowSums(f^2); # sum(f*(1-f)); else if( impurity == "entropy" ) score = rowSums(-f * log(f)); + else if( impurity == "rss" ) { # residual sum of squares + yhat = f %*% seq(1,ncol(f)); # yhat + res = outer(yhat, t(rowIndexMax(y2)), "-"); # yhat-y + score = rowSums((I * res)^2); # sum((yhat-y)^2) + } else stop("decisionTree: unsupported impurity measure: "+impurity); } diff --git a/scripts/builtin/lmPredict.dml b/scripts/builtin/lmPredict.dml index ce332a300c..6c017e2f0b 100644 --- a/scripts/builtin/lmPredict.dml +++ b/scripts/builtin/lmPredict.dml @@ -44,5 +44,5 @@ m_lmPredict = function(Matrix[Double] X, Matrix[Double] B, yhat = X %*% B[1:ncol(X),] + intercept; if( verbose ) - lmPredictStats(yhat, ytest); + lmPredictStats(yhat, ytest, TRUE); } diff --git a/scripts/builtin/lmPredictStats.dml b/scripts/builtin/lmPredictStats.dml index 986b98233c..48f4b8c9a9 100644 --- a/scripts/builtin/lmPredictStats.dml +++ b/scripts/builtin/lmPredictStats.dml @@ -26,6 +26,7 @@ # ------------------------------------------------------------------------------ # yhat column vector of predicted response values y # ytest column vector of actual response values y +# lm indicator if used for linear regression model # ------------------------------------------------------------------------------ # # OUTPUT: @@ -33,14 +34,17 @@ # R column vector holding avg_res, ss_avg_res, and R2 # ------------------------------------------------------------------------------ -m_lmPredictStats = function(Matrix[Double] yhat, Matrix[Double] ytest) +m_lmPredictStats = function(Matrix[Double] yhat, 
Matrix[Double] ytest, Boolean lm) return (Matrix[Double] R) { y_residual = ytest - yhat; avg_res = sum(y_residual) / nrow(ytest); ss_res = sum(y_residual^2); ss_avg_res = ss_res - nrow(ytest) * avg_res^2; - R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * (sum(ytest)/nrow(ytest))^2); + if( lm ) + R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * (sum(ytest)/nrow(ytest))^2); + else + R2 = sum((yhat - mean(ytest))^2) / sum((ytest - mean(ytest))^2) print("\nAccuracy:" + "\n--sum(ytest) = " + sum(ytest) + "\n--sum(yhat) = " + sum(yhat) + diff --git a/scripts/builtin/randomForest.dml b/scripts/builtin/randomForest.dml index fda266a3f4..7e39c9064e 100644 --- a/scripts/builtin/randomForest.dml +++ b/scripts/builtin/randomForest.dml @@ -42,7 +42,7 @@ # candidates at tree nodes: m = ceil(num_features^max_features) # max_values Parameter controlling the number of values per feature used # as split candidates: nb = ceil(num_values^max_values) -# impurity Impurity measure: entropy, gini (default) +# impurity Impurity measure: entropy, gini (default), rss (regression) # seed Fixed seed for randomization of samples and split candidates # verbose Flag indicating verbose debug output # ------------------------------------------------------------------------------ diff --git a/scripts/builtin/randomForestPredict.dml b/scripts/builtin/randomForestPredict.dml index 5923ef71f2..a003f26f7d 100644 --- a/scripts/builtin/randomForestPredict.dml +++ b/scripts/builtin/randomForestPredict.dml @@ -84,7 +84,7 @@ m_randomForestPredict = function(Matrix[Double] X, Matrix[Double] y = matrix(0,0 if( classify ) print("Accuracy (%): " + (sum(yhat == y) / nrow(y) * 100)); else - lmPredictStats(yhat, y); + lmPredictStats(yhat, y, FALSE); } if(verbose) { diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java index 808397c0d1..41bd6c651a 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java @@ -38,10 +38,9 @@ public class BuiltinDecisionTreeRealDataTest extends AutomatedTestBase { private final static String WINE_DATA = DATASET_DIR + "wine/winequality-red-white.csv"; private final static String WINE_TFSPEC = DATASET_DIR + "wine/tfspec.json"; - @Override public void setUp() { - for(int i=1; i<=2; i++) + for(int i=1; i<=3; i++) addTestConfiguration(TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); } @@ -86,9 +85,20 @@ public class BuiltinDecisionTreeRealDataTest extends AutomatedTestBase { @Test public void testRandomForestWine_MaxV1() { - //one tree with sample_frac=1 should be equivalent to decision tree runDecisionTree(2, WINE_DATA, WINE_TFSPEC, 0.989, 2, 1.0, ExecType.CP); } + + @Test + public void testDecisionTreeWineReg_MaxV1() { + //for regression we compare R2 and use rss to optimize + runDecisionTree(3, WINE_DATA, WINE_TFSPEC, 0.369, 1, 1.0, ExecType.CP); + } + + @Test + public void testRandomForestWineReg_MaxV1() { + //for regression we compare R2 and use rss to optimize + runDecisionTree(3, WINE_DATA, WINE_TFSPEC, 0.369, 2, 1.0, ExecType.CP); + } private void runDecisionTree(int test, String data, String tfspec, double minAcc, int dt, double maxV, ExecType instType) { Types.ExecMode platformOld = setExecMode(instType); diff --git a/src/test/resources/datasets/wine/tfspec.json b/src/test/resources/datasets/wine/tfspec.json index c8d573e85a..55b93fb786 100644 --- a/src/test/resources/datasets/wine/tfspec.json +++ b/src/test/resources/datasets/wine/tfspec.json @@ -12,6 +12,6 @@ {"id":8, "method":"equi-width", "numbins":10}, {"id":9, "method":"equi-width", "numbins":10}, {"id":10, "method":"equi-width", "numbins":10}, - {"id":11, "method":"equi-width", "numbins":10}, + {"id":11, "method":"equi-width", 
"numbins":50}, {"id":12, "method":"equi-width", "numbins":10},] } \ No newline at end of file diff --git a/src/test/scripts/functions/builtin/decisionTreeRealData3.dml b/src/test/scripts/functions/builtin/decisionTreeRealData3.dml new file mode 100644 index 0000000000..e53eb107a0 --- /dev/null +++ b/src/test/scripts/functions/builtin/decisionTreeRealData3.dml @@ -0,0 +1,50 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +F = read($1, data_type="frame", format="csv", header=FALSE); +tfspec = read($2, data_type="scalar", value_type="string"); + +R = matrix("1 1 1 1 1 1 1 1 1 1 1 2 1", rows=1, cols=13) + +[X, meta] = transformencode(target=F, spec=tfspec); +Y = X[,ncol(X)-1]; +X = cbind(X[,1:ncol(X)-2], X[,ncol(X)]); +X = replace(target=X, pattern=NaN, replacement=5); # 1 val + +if( $3==1 ) { + M = decisionTree(X=X, y=Y, ctypes=R, max_features=1, max_values=$4, + impurity="rss", min_split=10, min_leaf=4, seed=7, verbose=TRUE); + yhat = decisionTreePredict(X=X, ctypes=R, M=M) +} +else { + sf = 1.0/($3-1); + M = randomForest(X=X, y=Y, ctypes=R, sample_frac=sf, num_trees=$3-1, + impurity="rss", max_features=1, max_values=$4, + min_split=10, min_leaf=4, seed=7, verbose=TRUE); + yhat = randomForestPredict(X=X, ctypes=R, M=M) +} + +jspec="{ids:true,bin:[{id:1,method:equi-width,numbins:10}]}" +yhat2 = as.matrix(transformdecode(target=yhat, spec=jspec, meta=meta[,12])); + +R = lmPredictStats(yhat2, as.matrix(F[,ncol(F)-1]), FALSE) +acc = R[3,] +write(acc, $5);
