This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a3348e2425 [SYSTEMDS-3506] Add missing randomForestPredict built-in
function
a3348e2425 is described below
commit a3348e2425b0a2fff87898392594ed9d34fae2ef
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Mar 14 14:13:47 2023 +0100
[SYSTEMDS-3506] Add missing randomForestPredict built-in function
This patch adds the missing randomForestPredict builtin function, which
internally calls the decisionTreePredict and does feature sampling,
label aggregation (majority voting, average) and summary statistics.
Unfortunately, the decisionTreePredict still fails when run in a
parfor context, on which we will follow up separately.
---
scripts/builtin/randomForestPredict.dml | 108 +++++++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../scripts/functions/builtin/randomForestTest.dml | 1 +
3 files changed, 110 insertions(+)
diff --git a/scripts/builtin/randomForestPredict.dml
b/scripts/builtin/randomForestPredict.dml
new file mode 100644
index 0000000000..af2ec157ef
--- /dev/null
+++ b/scripts/builtin/randomForestPredict.dml
@@ -0,0 +1,108 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# This script implements random forest prediction for recoded and binned
+# categorical and numerical input features.
+#
+# INPUT:
+#
------------------------------------------------------------------------------
+# X Feature matrix in recoded/binned representation
+# y Label matrix in recoded/binned representation,
+# optional for accuracy evaluation
+# ctypes Row-Vector of column types [1 scale/ordinal, 2 categorical]
+# M Matrix M holding the learned trees (one tree per row),
+# see randomForest() for the detailed tree representation.
+# verbose Flag indicating verbose debug output
+#
------------------------------------------------------------------------------
+#
+# OUTPUT:
+#
------------------------------------------------------------------------------
+# yhat Label vector of predictions
+#
------------------------------------------------------------------------------
+
+m_randomForestPredict = function(Matrix[Double] X, Matrix[Double] y =
matrix(0,0,0),
+ Matrix[Double] ctypes, Matrix[Double] M, Boolean verbose = FALSE)
+ return(Matrix[Double] yhat)
+{
+ t1 = time();
+ # classification vs regression switch; currently forced to regression-style
+ # averaging (TODO below would derive it from the label column type in ctypes)
+ classify = FALSE; # TODO as.scalar(ctypes[1,ncol(X)+1]) == 2;
+ # labels are considered provided iff y has one entry per row of X
+ yExists = (nrow(X)==nrow(y));
+
+ if(verbose) {
+ print("randomForestPredict: called for batch of "+nrow(X)+" rows, model of
"
+ +nrow(M)+" trees, and with labels-provided "+yExists+".");
+ }
+
+ # scoring of num_tree decision trees
+ # Ytmp collects per-tree predictions: row i holds tree i's predictions
+ # for all nrow(X) examples
+ Ytmp = matrix(0, rows=nrow(M), cols=nrow(X));
+ # TODO parfor issue with decisionTreePredict
+ for(i in 1:nrow(M)) {
+ if( verbose )
+ print("randomForest: start scoring tree "+i+"/"+nrow(M)+".");
+
+ # step 1: sample features (consistent with training)
+ # the first ncol(X) cells of tree record i serve as a column-selection
+ # vector for removeEmpty (presumably 0/1 indicators written during
+ # training -- see randomForest() for the tree representation)
+ I2 = M[i, 1:ncol(X)];
+ Xi = removeEmpty(target=X, margin="cols", select=I2);
+
+ # step 2: score decision tree
+ t2 = time();
+ ret = decisionTreePredict(X=Xi, M=M[i,], strategy="TT");
+ # store tree i's column-vector predictions as row i of Ytmp
+ Ytmp[i,1:nrow(ret)] = t(ret);
+ if( verbose )
+ print("-- ["+i+"] scored decision tree in "+(time()-t2)/1e9+" seconds.");
+ }
+
+ # label aggregation (majority voting / average)
+ yhat = matrix(0, nrow(X), 1);
+ if( classify ) {
+ # majority vote: most frequent predicted label over all trees per example
+ parfor(i in 1:nrow(X))
+ yhat[i,1] = rowIndexMax(t(table(Ytmp[,i],1)));
+ }
+ else {
+ # regression: mean prediction over all trees per example
+ yhat = t(colSums(Ytmp)/nrow(M));
+ }
+
+ # summary statistics
+ if( yExists & verbose ) {
+ if( classify ) {
+ accuracy = sum(yhat == y) / nrow(y) * 100;
+ print("Accuracy (%): " + accuracy);
+ }
+ else {
+ # TODO eliminate redundancy with lmPredict
+ y_residual = y - yhat;
+ avg_res = sum(y_residual) / nrow(y);
+ ss_res = sum(y_residual^2);
+ ss_avg_res = ss_res - nrow(y) * avg_res^2;
+ # coefficient of determination: 1 - SS_res / SS_tot
+ R2 = 1 - ss_res / (sum(y^2) - nrow(y) * (sum(y)/nrow(y))^2);
+ print("\nAccuracy:" +
+ "\n--sum(y) = " + sum(y) +
+ "\n--sum(yhat) = " + sum(yhat) +
+ "\n--AVG_RES_Y: " + avg_res +
+ "\n--SS_AVG_RES_Y: " + ss_avg_res +
+ "\n--R2: " + R2 );
+ }
+ }
+
+ if(verbose) {
+ print("randomForest: scored batch of "+nrow(X)+" rows with "+nrow(M)+"
trees in "+(time()-t1)/1e9+" seconds.");
+ }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index e627adb286..068968eb87 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -243,6 +243,7 @@ public enum Builtins {
QR("qr", false, ReturnType.MULTI_RETURN),
QUANTILE("quantile", false),
RANDOM_FOREST("randomForest", true),
+ RANDOM_FOREST_PREDICT("randomForestPredict", true),
RANGE("range", false),
RBIND("rbind", false),
REMOVE("remove", false, ReturnType.MULTI_RETURN),
diff --git a/src/test/scripts/functions/builtin/randomForestTest.dml
b/src/test/scripts/functions/builtin/randomForestTest.dml
index 092d52f8cb..25fa54eea8 100644
--- a/src/test/scripts/functions/builtin/randomForestTest.dml
+++ b/src/test/scripts/functions/builtin/randomForestTest.dml
@@ -39,5 +39,6 @@ jspec = "{ids: true, bin: ["
R = matrix(1, rows=1, cols=ncol(X));
M = randomForest(X=X, y=Y, ctypes=R, num_trees=num_trees, seed=7,
max_depth=depth, min_leaf=num_leafs, impurity=impurity, verbose=TRUE);
+randomForestPredict(X=X, y=Y, ctypes=R, M=M, verbose=TRUE);
write(M, $7);