This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
     new 75e7e64f22 [SYSTEMDS-3149] Fix misc issues decisionTree/randomForest training
75e7e64f22 is described below
commit 75e7e64f228cccfe71017199799298e227e4bd23
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Apr 11 21:31:27 2023 +0200
[SYSTEMDS-3149] Fix misc issues decisionTree/randomForest training
This patch fixes various issues in the new decisionTree and randomForest
built-in functions and adds new, stricter tests:
* randomForest validation checks and parameters (consistent with DT)
* randomForest correct feature map with feature_frac=1.0
* decisionTree simplification of leaf label computation
* synchronized deep copy of hop-DAGs to avoid race conditions in parfor
* added missing size propagation on spark rev operations
* new tests with randomForest that check equivalent results to DT
  with num_trees=1 (sketched below) and reasonable results with larger ensembles
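For context, the DT/RF equivalence the new tests rely on can be sketched in
DML; this is a hedged illustration, not part of the patch (X, y, R stand for
recoded features, labels, and column types as in the test script below). With
sample_frac=1.0, feature_frac=1.0, and num_trees=1, randomForest should
degenerate to a single decisionTree:

    M1 = decisionTree(X=X, y=y, ctypes=R, seed=7);
    M2 = randomForest(X=X, y=y, ctypes=R, num_trees=1,
      sample_frac=1.0, feature_frac=1.0, seed=7);
    p1 = decisionTreePredict(X=X, y=y, ctypes=R, M=M1);
    p2 = randomForestPredict(X=X, y=y, ctypes=R, M=M2);
    print("disagreeing predictions: " + sum(p1 != p2)); # ideally 0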
---
scripts/builtin/decisionTree.dml | 4 ++--
scripts/builtin/randomForest.dml | 20 +++++++++++++-------
.../instructions/spark/ReorgSPInstruction.java | 5 ++++-
.../apache/sysds/runtime/util/ProgramConverter.java | 5 ++++-
.../part1/BuiltinDecisionTreeRealDataTest.java | 20 +++++++++++++++++---
.../functions/builtin/decisionTreeRealData.dml | 17 +++++++++++++----
6 files changed, 53 insertions(+), 18 deletions(-)
diff --git a/scripts/builtin/decisionTree.dml b/scripts/builtin/decisionTree.dml
index 5e72127dd1..4d4e273c65 100644
--- a/scripts/builtin/decisionTree.dml
+++ b/scripts/builtin/decisionTree.dml
@@ -226,8 +226,8 @@ computeLeafLabel = function(Matrix[Double] y2, Matrix[Double] I, Boolean classif
return(Double label)
{
f = (I %*% y2) / sum(I);
-  label = ifelse(classify,
-    as.scalar(rowIndexMax(f)), sum(t(f)*seq(1,ncol(f))));
+  label = as.scalar(ifelse(classify,
+    rowIndexMax(f), f %*% seq(1,ncol(f))));
if(verbose)
print("-- leaf node label: " + label +" ("+sum(I)*max(f)+"/"+sum(I)+")");
}
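A note on the leaf-label change above: f = (I %*% y2) / sum(I) is a 1xC row
vector of label fractions, so both branches of the new ifelse already yield
1x1 matrices and a single as.scalar suffices. A minimal standalone DML sketch
with made-up values (not from the patch):

    f = matrix("0.2 0.5 0.3", rows=1, cols=3);  # label distribution at a leaf
    classify = TRUE;
    label = as.scalar(ifelse(classify,
      rowIndexMax(f),            # classification: majority label, here 2
      f %*% seq(1,ncol(f))));    # regression: expectation 0.2*1+0.5*2+0.3*3
    print("leaf label: " + label);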
diff --git a/scripts/builtin/randomForest.dml b/scripts/builtin/randomForest.dml
index 176628e3a6..37f7f64fb4 100644
--- a/scripts/builtin/randomForest.dml
+++ b/scripts/builtin/randomForest.dml
@@ -37,6 +37,7 @@
# feature_frac Sample fraction of features for each tree in the forest
# max_depth Maximum depth of the learned tree (stopping criterion)
# min_leaf Minimum number of samples in leaf nodes (stopping criterion)
+# min_split Minimum number of samples in leaf for attempting a split
# max_features Parameter controlling the number of features used as split
# candidates at tree nodes: m = ceil(num_features^max_features)
# impurity Impurity measure: entropy, gini (default)
@@ -68,7 +69,7 @@
m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] ctypes,
Int num_trees = 16, Double sample_frac = 0.1, Double feature_frac = 1.0,
- Int max_depth = 10, Int min_leaf = 20, Double max_features = 0.5,
+  Int max_depth = 10, Int min_leaf = 20, Int min_split = 50, Double max_features = 0.5,
String impurity = "gini", Int seed = -1, Boolean verbose = FALSE)
return(Matrix[Double] M)
{
@@ -81,6 +82,8 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty
}
if(ncol(ctypes) != ncol(X)+1)
stop("randomForest: inconsistent num features (incl. label) and col types:
"+ncol(X)+" vs "+ncol(ctypes)+".");
+ if( sum(X<=0) != 0 )
+ stop("randomForest: feature matrix X is not properly recoded/binned:
"+sum(X<=0));
if(sum(y <= 0) != 0)
stop("randomForest: y is not properly recoded and binned (contiguous
positive integers).");
if(max(y) == 1)
@@ -91,16 +94,19 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty
# training of num_tree decision trees
M = matrix(0, rows=num_trees, cols=2*(2^max_depth-1));
- F = matrix(0, rows=num_trees, cols=ncol(X));
+ F = matrix(1, rows=num_trees, cols=ncol(X));
parfor(i in 1:num_trees) {
if( verbose )
print("randomForest: start training tree "+i+"/"+num_trees+".");
# step 1: sample data
- si1 = as.integer(as.scalar(randSeeds[3*(i-1)+1,1]));
- I1 = rand(rows=nrow(X), cols=1, seed=si1) <= sample_frac;
- Xi = removeEmpty(target=X, margin="rows", select=I1);
- yi = removeEmpty(target=y, margin="rows", select=I1);
+ Xi = X; yi = y;
+ if( sample_frac < 1.0 ) {
+ si1 = as.integer(as.scalar(randSeeds[3*(i-1)+1,1]));
+ I1 = rand(rows=nrow(X), cols=1, seed=si1) <= sample_frac;
+ Xi = removeEmpty(target=X, margin="rows", select=I1);
+ yi = removeEmpty(target=y, margin="rows", select=I1);
+ }
# step 2: sample features
if( feature_frac < 1.0 ) {
@@ -116,7 +122,7 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty
# step 3: train decision tree
t2 = time();
si3 = as.integer(as.scalar(randSeeds[3*(i-1)+3,1]));
- Mtemp = decisionTree(X=Xi, y=yi, ctypes=ctypes, max_depth=max_depth,
+    Mtemp = decisionTree(X=Xi, y=yi, ctypes=ctypes, max_depth=max_depth, min_split=min_split,
      min_leaf=min_leaf, max_features=max_features, impurity=impurity, seed=si3, verbose=verbose);
M[i,1:length(Mtemp)] = matrix(Mtemp, rows=1, cols=length(Mtemp));
if( verbose )
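Why F now defaults to ones: each row of F marks the features a tree was
trained on, and with feature_frac=1.0 the sampling branch that would populate
it is skipped, so a zero default would declare every feature unused. A hedged
DML sketch of the selection semantics (using removeEmpty for column selection
is an assumption for illustration, not necessarily the builtin's mechanism):

    F = matrix(1, rows=1, cols=ncol(X));                  # all features marked used
    Xi = removeEmpty(target=X, margin="cols", select=F);  # keeps every column, Xi == X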
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
index 5b6f2e4e3d..f14afa9009 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
@@ -234,6 +234,9 @@ public class ReorgSPInstruction extends UnarySPInstruction {
  boolean ixret = sec.getScalarInput(_ixret).getBooleanValue();
  mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
}
+ else { //e.g., rev
+ mcOut.set(mc1);
+ }
}
//infer initially unknown nnz from input
@@ -241,7 +244,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
boolean sortIx = getOpcode().equalsIgnoreCase("rsort")
&& sec.getScalarInput(_ixret.getName(), _ixret.getValueType(),
_ixret.isLiteral()).getBooleanValue();
if( sortIx )
mcOut.setNonZeros(mc1.getRows());
- else //default (r', rdiag, rsort data)
+ else //default (r', rdiag, rev, rsort data)
mcOut.setNonZeros(mc1.getNonZeros());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index 34a5287b70..8fbfe31125 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -593,7 +593,10 @@ public class ProgramConverter
ret.setReadVariables( sb.variablesRead() );
//deep copy hops dag for concurrent recompile
- ArrayList<Hop> hops =
Recompiler.deepCopyHopsDag( sb.getHops() );
+ ArrayList<Hop> hops = sb.getHops();
+ synchronized(hops) { // guard concurrent
recompile
+ hops = Recompiler.deepCopyHopsDag( hops
);
+ }
if( !plain )
Recompiler.updateFunctionNames( hops,
pid );
ret.setHops( hops );
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
index 2af6784c36..f797cfff09 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
@@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
import org.junit.Test;
@@ -42,10 +43,22 @@ public class BuiltinDecisionTreeRealDataTest extends AutomatedTestBase {
@Test
public void testDecisionTreeTitanic() {
-    runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, ExecType.CP);
+    runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, 1, ExecType.CP);
+ }
+
+ @Test
+ public void testRandomForestTitanic1() {
+    //one tree with sample_frac=1 should be equivalent to decision tree
+    runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, 2, ExecType.CP);
+ }
+
+ @Test
+ public void testRandomForestTitanic8() {
+    //8 trees with sample fraction 0.125 each, accuracy 0.785 due to randomness
+    runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.793, 9, ExecType.CP);
}
-  private void runDecisionTree(String data, String tfspec, double minAcc, ExecType instType) {
+  private void runDecisionTree(String data, String tfspec, double minAcc, int dt, ExecType instType) {
Types.ExecMode platformOld = setExecMode(instType);
try {
loadTestConfiguration(getTestConfiguration(TEST_NAME));
@@ -53,12 +66,13 @@ public class BuiltinDecisionTreeRealDataTest extends AutomatedTestBase {
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats",
- "-args", data, tfspec, output("R")};
+ "-args", data, tfspec, String.valueOf(dt),
output("R")};
runTest(true, false, null, -1);
      double acc = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
Assert.assertTrue(acc >= minAcc);
+      Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
}
finally {
rtplatform = platformOld;
diff --git a/src/test/scripts/functions/builtin/decisionTreeRealData.dml b/src/test/scripts/functions/builtin/decisionTreeRealData.dml
index 775a73a263..f61b2de77e 100644
--- a/src/test/scripts/functions/builtin/decisionTreeRealData.dml
+++ b/src/test/scripts/functions/builtin/decisionTreeRealData.dml
@@ -30,11 +30,20 @@ Y = X[, ncol(X)]
X = X[, 1:ncol(X)-1]
X = imputeByMode(X);
-M = decisionTree(X=X, y=Y, ctypes=R, max_features=1, min_split=8, min_leaf=5, verbose=TRUE);
-yhat = decisionTreePredict(X=X, y=Y, ctypes=R, M=M)
+if( $3==1 ) {
+  M = decisionTree(X=X, y=Y, ctypes=R, max_features=1,
+    min_split=10, min_leaf=4, seed=7, verbose=TRUE);
+ yhat = decisionTreePredict(X=X, y=Y, ctypes=R, M=M)
+}
+else {
+ sf = 1.0/($3-1);
+  M = randomForest(X=X, y=Y, ctypes=R, sample_frac=sf, num_trees=$3-1, max_features=1,
+ min_split=10, min_leaf=4, seed=7, verbose=TRUE);
+ yhat = randomForestPredict(X=X, y=Y, ctypes=R, M=M)
+}
acc = as.matrix(mean(yhat == Y))
err = 1-(acc);
-print("accuracy of DT: "+as.scalar(acc))
+print("accuracy: "+as.scalar(acc))
-write(acc, $3);
+write(acc, $4);
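Reading the script together with the updated Java test driver, the positional
arguments encode the model choice; a summary note (values taken from the tests
above, invocation details otherwise assumed):

    # $1/$2: data file and transform spec; $4: output file for the accuracy
    # $3 == 1   -> single decisionTree
    # $3 == k+1 -> randomForest with k trees, sample_frac = 1/k per tree
    #   e.g., $3 = 2 -> 1 tree on the full sample (DT-equivalent),
    #         $3 = 9 -> 8 trees with sample_frac = 0.125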