This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 11418431bc [SYSTEMDS-3506] Finalize DecisionTree Predict Methods
11418431bc is described below
commit 11418431bc995a02d6910331d3fe72492f2bb1b6
Author: e-strauss <[email protected]>
AuthorDate: Fri Aug 23 20:09:02 2024 +0200
[SYSTEMDS-3506] Finalize DecisionTree Predict Methods
Closes #2069.
---
scripts/builtin/decisionTreePredict.dml | 66 +++++++--
.../part1/BuiltinDecisionTreePredictTest.java | 161 +++++++++++++++++----
.../builtin/part1/BuiltinDecisionTreeTest.java | 2 +-
3 files changed, 191 insertions(+), 38 deletions(-)
diff --git a/scripts/builtin/decisionTreePredict.dml
b/scripts/builtin/decisionTreePredict.dml
index 50a8c2a912..4585d56d08 100644
--- a/scripts/builtin/decisionTreePredict.dml
+++ b/scripts/builtin/decisionTreePredict.dml
@@ -92,7 +92,7 @@ predict_GEMM = function (Matrix[Double] M, Matrix[Double] X)
[A, B, C, D, E] = createGEMMNodeTensors(M, ncol(X));
# scoring pipline, evaluating all nodes in parallel
- Y = rowIndexMax(((((X %*% A) < B) %*% C) == D) %*% E);
+ Y = rowIndexMax(((((X %*% A) <= B) %*% C) == D) %*% E);
}
createTTNodeTensors = function( Matrix[Double] M )
@@ -125,22 +125,64 @@ createGEMMNodeTensors = function( Matrix[Double] M, Int m
)
return (Matrix[Double] A, Matrix[Double] B, Matrix[Double] C,
Matrix[Double] D, Matrix[Double] E)
{
- #TODO update for new model layout and generalize
- stop("GEMM not fully supported yet");
-
- nin = sum(M[2,]!=0); # num inner nodes
+ M2 = matrix(M, rows=ncol(M)/2, cols=2)
+ NID = seq(1, nrow(M2))
# predicate map [#feat x #inodes] and values [1 x #inodes]
- I1 = removeEmpty(target=M[3,], margin="cols");
- A = table(I1, seq(1,nin), m, nin);
- B = removeEmpty(target=M[6,], margin="cols", select=M[2,]!=0);
+ is_inner = M2[,1]!=0
+ I1 = removeEmpty(target=NID, margin="rows", select=is_inner)
+ pivot = removeEmpty(target=M2[,1], margin="rows", select=is_inner)
+ nin = nrow(I1)
+ A = table(pivot, seq(1,nin), m, nin)
+ B = t(removeEmpty(target=M2[,2], margin="rows", select=is_inner))
# bucket paths [#inodes x #paths] and path sums
- I2 = (M[2,] == 0)
- np = ncol(M) - nin;
- C = matrix("1 -1", rows=1, cols=2); # TODO general case
+ is_leaf = (!is_inner & M2[,2]!=0)
+ leaf_ids = t(removeEmpty(target=NID, margin="rows", select=is_leaf))
+ last_leaf = as.scalar(leaf_ids[1,ncol(leaf_ids)])
+ leaf_classes = removeEmpty(target=M2[,2], margin="rows", select=is_leaf)
+ nl = ncol(leaf_ids)
+
+ # iterate over each inner node and check for each leaf node whether it is in the left subtree (1), in the right subtree (-1), or not in the subtree at all (0)
+ # | i |
+ # / \
+ # |2i| |2i+1|
+ # / \ / \
+ # |4i| |4i+1| |4i+2| |4i+3|
+ #
+ # left_subtree_of_node(i) = { x | (2^j)*i <= x < (2^j)*i + 2^(j-1), for j elem {1, 2, 3, ...}} -> j is the depth below node i
+ # right_subtree_of_node(i) = { x | (2^j)*i + 2^(j-1) <= x < (2^j)*(i+1), for j elem {1, 2, 3, ...}}
+
+ C = matrix(0, nin, nl)
+ parfor(i in seq(1, nin)){
+ boundary_left = 2*as.scalar(I1[i, 1]) # initialize the left boundary with
the left child of the inner node
+ out = matrix(0, 1, nl)
+ step_size = 1
+
+ # iterate each level of tree [log(max_node_id) iterations]
+ while(boundary_left < last_leaf) {
+
+ # left side
+ subset_lower_bound = leaf_ids >= boundary_left
+ boundary_right = boundary_left + step_size
+ subset_upper_bound = leaf_ids < boundary_right
+ ones = subset_lower_bound & subset_upper_bound
+ out = out + ones
+
+ # right side
+ subset_lower_bound = !subset_upper_bound #reuse by inverting
+ boundary_right = boundary_right + step_size
+ subset_upper_bound = leaf_ids < boundary_right
+ ones = subset_lower_bound & subset_upper_bound
+ out = out - ones
+
+ step_size = step_size*2 # with each level the amount of nodes in subtree
level doubles
+ boundary_left = boundary_left*2 # new left boundary is the left child of
the previous left boundary
+ }
+ C[i,] = out
+ }
D = colSums(max(C, 0));
# class map [#paths x #classes]
- E = table(seq(1,ncol(C)), t(M[4,(ncol(M)-ncol(C)+1):ncol(M)]));
+ E = table(seq(1,ncol(C)),leaf_classes)
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
index 0cecc0e15c..6eb22da335 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
@@ -27,13 +27,12 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
import org.junit.Test;
public class BuiltinDecisionTreePredictTest extends AutomatedTestBase {
private final static String TEST_NAME = "decisionTreePredict";
private final static String TEST_DIR = "functions/builtin/";
- private static final String TEST_CLASS_DIR = TEST_DIR +
BuiltinDecisionTreeTest.class.getSimpleName() + "/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
BuiltinDecisionTreePredictTest.class.getSimpleName() + "/";
private final static double eps = 1e-10;
@@ -43,28 +42,58 @@ public class BuiltinDecisionTreePredictTest extends
AutomatedTestBase {
}
@Test
- public void testDecisionTreeTTPredictDefaultCP() {
- runDecisionTreePredict(true, ExecType.CP, "TT");
+ public void testDecisionTreeTTPredictDefaultCP1() {
+ runDecisionTreePredict(true, ExecType.CP, "TT", 1);
}
+ @Test
+ public void testDecisionTreeTTPredictDefaultCP2() {
+ runDecisionTreePredict(true, ExecType.CP, "TT", 2);
+ }
+
+ @Test
+ public void testDecisionTreeTTPredictDefaultCP3() {
+ runDecisionTreePredict(true, ExecType.CP, "TT", 3);
+ }
+
+ @Test
+ public void testDecisionTreeTTPredictDefaultCP4() {
+ runDecisionTreePredict(true, ExecType.CP, "TT", 4);
+ }
+
+
@Test
public void testDecisionTreeTTPredictSP() {
- runDecisionTreePredict(true, ExecType.SPARK, "TT");
+ runDecisionTreePredict(true, ExecType.SPARK, "TT", 1);
}
@Test
- @Ignore
- public void testDecisionTreeGEMMPredictDefaultCP() {
- runDecisionTreePredict(true, ExecType.CP, "GEMM");
+ public void testDecisionTreeGEMMPredictDefaultCP1() {
+ runDecisionTreePredict(true, ExecType.CP, "GEMM", 1);
+ }
+
+ @Test
+ public void testDecisionTreeGEMMPredictDefaultCP2() {
+ runDecisionTreePredict(true, ExecType.CP, "GEMM", 2);
}
@Test
- @Ignore
+ public void testDecisionTreeGEMMPredictDefaultCP3() {
+ runDecisionTreePredict(true, ExecType.CP, "GEMM", 3);
+ }
+
+ @Test
+ public void testDecisionTreeGEMMPredictDefaultCP4() {
+ runDecisionTreePredict(true, ExecType.CP, "GEMM", 4);
+ }
+
+
+ @Test
public void testDecisionTreeGEMMPredictSP() {
- runDecisionTreePredict(true, ExecType.SPARK, "GEMM");
+ runDecisionTreePredict(true, ExecType.SPARK, "GEMM", 1);
}
- private void runDecisionTreePredict(boolean defaultProb, ExecType
instType, String strategy) {
+ private void runDecisionTreePredict(boolean defaultProb, ExecType
instType, String strategy, int test_case) {
Types.ExecMode platformOld = setExecMode(instType);
try {
loadTestConfiguration(getTestConfiguration(TEST_NAME));
@@ -74,20 +103,102 @@ public class BuiltinDecisionTreePredictTest extends
AutomatedTestBase {
programArgs = new String[] {"-args", input("M"),
input("X"), strategy, output("Y")};
//data and model consistent with decision tree test
- double[][] X = {
- {3, 1, 2, 1, 5},
- {2, 1, 2, 2, 4},
- {1, 1, 1, 3, 3},
- {4, 2, 1, 4, 2},
- {2, 2, 1, 5, 1},};
- double[][] M = {{1.0, 2.0, 0.0, 1.0, 0.0, 2.0}};
-
+ double[][] X = null;
+ double[][] M = null;
+
HashMap<MatrixValue.CellIndex, Double> expected_Y = new
HashMap<>();
- expected_Y.put(new MatrixValue.CellIndex(1, 1), 2.0);
- expected_Y.put(new MatrixValue.CellIndex(2, 1), 1.0);
- expected_Y.put(new MatrixValue.CellIndex(3, 1), 1.0);
- expected_Y.put(new MatrixValue.CellIndex(4, 1), 2.0);
- expected_Y.put(new MatrixValue.CellIndex(5, 1), 1.0);
+ switch(test_case){
+ case 1:
+ double[][] X1 = {
+ {3, 1, 2, 1, 5},
+ {2, 1, 2, 2, 4},
+ {1, 1, 1, 3, 3},
+ {4, 2, 1, 4, 2},
+ {2, 2, 1, 5, 1},};
+ double[][] M1 = {{1.0, 2.0, 0.0, 1.0,
0.0, 2.0}};
+
+ expected_Y.put(new
MatrixValue.CellIndex(1, 1), 2.0);
+ expected_Y.put(new
MatrixValue.CellIndex(2, 1), 1.0);
+ expected_Y.put(new
MatrixValue.CellIndex(3, 1), 1.0);
+ expected_Y.put(new
MatrixValue.CellIndex(4, 1), 2.0);
+ expected_Y.put(new
MatrixValue.CellIndex(5, 1), 1.0);
+ X = X1;
+ M = M1;
+ break;
+ case 2:
+ double[][] X2 = {
+ {3, 1, 2, 1},
+ {2, 1, 2, 6},
+ {1, 1, 1, 3},
+ {9, 2, 1, 7},
+ {2, 2, 1, 1},};
+ double[][] M2 = {{4, 5, 0, 2, 1, 7, 0,
0, 0, 0, 0, 2, 0, 1}};
+
+ expected_Y.put(new
MatrixValue.CellIndex(1, 1), 2.0);
+ expected_Y.put(new
MatrixValue.CellIndex(2, 1), 2.0);
+ expected_Y.put(new
MatrixValue.CellIndex(3, 1), 2.0);
+ expected_Y.put(new
MatrixValue.CellIndex(4, 1), 1.0);
+ expected_Y.put(new
MatrixValue.CellIndex(5, 1), 2.0);
+ X = X2;
+ M = M2;
+ break;
+ case 3:
+ double[][] X3 = {
+ {1, 1, 1},
+ {1, 1, 7,},
+ {1, 5, 1},
+ {1, 5, 7,},};
+ double[][] M3 = {{1, 5, 2, 4, 2, 4, 3,
6, 3, 6, 3, 6, 3, 6, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8}};
+
+ expected_Y.put(new
MatrixValue.CellIndex(1, 1), 1.0);
+ expected_Y.put(new
MatrixValue.CellIndex(2, 1), 2.0);
+ expected_Y.put(new
MatrixValue.CellIndex(3, 1), 3.0);
+ expected_Y.put(new
MatrixValue.CellIndex(4, 1), 4.0);
+ X = X3;
+ M = M3;
+ break;
+ case 4:
+ double[][] X4 = {
+ {1, 1, 1, 1},
+ {4, 1, 1, 1},
+ {1, 1, 7, 1},
+ {4, 1, 7, 1},
+ {1, 5, 1, 1},
+ {4, 5, 1, 1},
+ {1, 5, 7, 1},
+ {4, 5, 7, 1},
+ {1, 1, 1, 6},
+ {4, 1, 1, 6},
+ {1, 1, 7, 6},
+ {4, 1, 7, 6},
+ {1, 5, 1, 6},
+ {4, 5, 1, 6},
+ {1, 5, 7, 6},
+ {4, 5, 7, 6},};
+ double[][] M4 = {{4, 5, 2, 4, 2, 4, 3,
6, 3, 6, 3, 6, 3, 6, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1,
+ 3, 1, 3, 0, 1, 0, 2, 0,
3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13,
+ 0, 14, 0, 15, 0, 16}};
+
+ expected_Y.put(new
MatrixValue.CellIndex(1, 1), 1.0);
+ expected_Y.put(new
MatrixValue.CellIndex(2, 1), 2.0);
+ expected_Y.put(new
MatrixValue.CellIndex(3, 1), 3.0);
+ expected_Y.put(new
MatrixValue.CellIndex(4, 1), 4.0);
+ expected_Y.put(new
MatrixValue.CellIndex(5, 1), 5.0);
+ expected_Y.put(new
MatrixValue.CellIndex(6, 1), 6.0);
+ expected_Y.put(new
MatrixValue.CellIndex(7, 1), 7.0);
+ expected_Y.put(new
MatrixValue.CellIndex(8, 1), 8.0);
+ expected_Y.put(new
MatrixValue.CellIndex(9, 1), 9.0);
+ expected_Y.put(new
MatrixValue.CellIndex(10, 1), 10.0);
+ expected_Y.put(new
MatrixValue.CellIndex(11, 1), 11.0);
+ expected_Y.put(new
MatrixValue.CellIndex(12, 1), 12.0);
+ expected_Y.put(new
MatrixValue.CellIndex(13, 1), 13.0);
+ expected_Y.put(new
MatrixValue.CellIndex(14, 1), 14.0);
+ expected_Y.put(new
MatrixValue.CellIndex(15, 1), 15.0);
+ expected_Y.put(new
MatrixValue.CellIndex(16, 1), 16.0);
+ X = X4;
+ M = M4;
+ break;
+ }
writeInputMatrixWithMTD("M", M, true);
writeInputMatrixWithMTD("X", X, true);
@@ -98,7 +209,7 @@ public class BuiltinDecisionTreePredictTest extends
AutomatedTestBase {
TestUtils.compareMatrices(expected_Y, actual_Y, eps,
"Expected-DML", "Actual-DML");
}
finally {
- rtplatform = platformOld;
+ resetExecMode(platformOld);
}
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java
index f8ac8397cb..a8b3112992 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java
@@ -86,7 +86,7 @@ public class BuiltinDecisionTreeTest extends
AutomatedTestBase {
TestUtils.compareMatrices(expected_M, actual_M, eps,
"Expected-DML", "Actual-DML");
}
finally {
- rtplatform = platformOld;
+ resetExecMode(platformOld);
}
}
}