This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 44e4294 [SYSTEMDS-2870] Add Builtin Decision Tree DIA project WS2020/21 Closes #1145
44e4294 is described below
commit 44e42945541150d7e870c5b89e79c085d63b4e78
Author: Patrick Gaisberger <[email protected]>
AuthorDate: Fri Nov 27 18:15:52 2020 +0100
[SYSTEMDS-2870] Add Builtin Decision Tree
DIA project WS2020/21
Closes #1145
Co-authored-by: David Egger <[email protected]>
Co-authored-by: Adna Ribo <[email protected]>
---
scripts/builtin/decisionTree.dml | 438 +++++++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../functions/builtin/BuiltinDecisionTreeTest.java | 95 +++++
.../test/functions/builtin/BuiltinLmTest.java | 2 +-
.../scripts/functions/builtin/decisionTree.dml | 26 ++
5 files changed, 561 insertions(+), 1 deletion(-)
diff --git a/scripts/builtin/decisionTree.dml b/scripts/builtin/decisionTree.dml
new file mode 100644
index 0000000..ac9f82b
--- /dev/null
+++ b/scripts/builtin/decisionTree.dml
@@ -0,0 +1,438 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+#
+# THIS SCRIPT IMPLEMENTS CLASSIFICATION TREES WITH BOTH SCALE AND CATEGORICAL FEATURES
+#
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME      TYPE             DEFAULT   MEANING
+# ---------------------------------------------------------------------------------------------
+# X         Matrix[Double]   ---       Feature matrix X; note that X needs to be both recoded and dummy coded
+# Y         Matrix[Double]   ---       Label matrix Y; note that Y needs to be both recoded and dummy coded
+# R         Matrix[Double]   " "       Matrix R which for each feature in X contains the following information:
+#                                      - R[1,]: row vector indicating whether a feature is scale or categorical;
+#                                        1 indicates a scale feature, other positive integers indicate the
+#                                        number of categories
+#                                      If R is not provided, all variables are assumed to be scale by default
+# bins      Integer          10        Number of equi-height bins per scale feature from which thresholds are chosen
+# depth     Integer          20        Maximum depth of the learned tree
+# verbose   Boolean          FALSE     Boolean specifying if the algorithm should print information while executing
+# ---------------------------------------------------------------------------------------------
+# OUTPUT:
+# Matrix M where each column corresponds to a node in the learned tree and each row contains the following information:
+#   M[1,j]: id of node j (in a complete binary tree)
+#   M[2,j]: Offset (no. of columns) to the left child of j if j is an internal node, otherwise 0
+#   M[3,j]: Feature index of the feature (scale feature id if the feature is scale, or categorical feature id
+#           if the feature is categorical) that node j looks at if j is an internal node, otherwise 0
+#   M[4,j]: Type of the feature that node j looks at if j is an internal node: holds the same information as the R input vector
+#   M[5,j]: If j is an internal node: 1 if the feature chosen for j is scale, otherwise the size of the subset of values
+#           stored in rows 6,7,... if j is categorical
+#           If j is a leaf node: number of misclassified samples reaching node j
+#   M[6:,j]: If j is an internal node: the threshold the example's feature value is compared to is stored at M[6,j]
+#            if the feature chosen for j is scale; otherwise, if the feature chosen for j is categorical,
+#            rows 6,7,... depict the value subset chosen for j
+#            If j is a leaf node: 1 if j is impure and the number of samples at j > threshold, otherwise 0
+# ---------------------------------------------------------------------------------------------
+
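+# Example usage (see also src/test/scripts/functions/builtin/decisionTree.dml below):
+#   X = read($1);                           # recoded/dummy-coded feature matrix
+#   Y = read($2);                           # recoded/dummy-coded label matrix
+#   R = read($3);                           # feature types: 1 = scale, k > 1 = k categories
+#   M = decisionTree(X = X, Y = Y, R = R);
+#   write(M, $4);
+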
+m_decisionTree = function(
+ Matrix[Double] X,
+ Matrix[Double] Y,
+ Matrix[Double] R,
+ Integer bins = 10,
+ Integer depth = 20,
+ Boolean verbose = FALSE
+) return (Matrix[Double] M) {
+ if (verbose) {
+ print("Executing Decision Tree:")
+ }
+ node_queue = matrix(1, rows=1, cols=1) # Add first Node
+ impurity_queue = matrix(1, rows=1, cols=1)
+  use_cols_queue = matrix(1, rows=ncol(X), cols=1) # Add first bool Vector with all cols <=> (use all cols)
+  use_rows_queue = matrix(1, rows=nrow(X), cols=1) # Add first bool Vector with all rows <=> (use all rows)
+ queue_length = 1
+ M = matrix(0, rows = 0, cols = 0)
+ while (queue_length > 0) {
+ [node_queue, node] = dataQueuePop(node_queue)
+ [use_rows_queue, use_rows_vector] = dataQueuePop(use_rows_queue)
+ [use_cols_queue, use_cols_vector] = dataQueuePop(use_cols_queue)
+
+ available_rows = calcAvailable(use_rows_vector)
+ available_cols = calcAvailable(use_cols_vector)
+ [impurity_queue, parent_impurity] = dataQueuePop(impurity_queue)
+ create_child_nodes_flag = FALSE
+ if (verbose) {
+ print("Popped Node: " + as.scalar(node))
+ print("Rows: " + toString(t(use_rows_vector)))
+ print("Cols: " + toString(t(use_cols_vector)))
+ print("Available Rows: " + available_rows)
+ print("Available Cols: " + available_cols)
+ print("Parent impurity: " + as.scalar(parent_impurity))
+ }
+
+ node_depth = calculateNodeDepth(node)
+    used_col = 0.0
+    type = 0.0
+    if (node_depth < depth & available_rows > 1 & available_cols > 0 & as.scalar(parent_impurity) > 0) {
+      [impurity, used_col, threshold, type] = calcBestSplittingCriteria(X, Y, R, use_rows_vector, use_cols_vector, bins)
+ create_child_nodes_flag = impurity < as.scalar(parent_impurity)
+ if (verbose) {
+ print("Current impurity: " + impurity)
+ print("Current threshold: "+ toString(t(threshold)))
+ }
+ }
+ if (verbose) {
+
+ print("Current column: " + used_col)
+ print("Current type: " + type)
+ }
+ if (create_child_nodes_flag) {
+ [left, right] = calculateChildNodes(node)
+ node_queue = dataQueuePush(left, right, node_queue)
+
+      [new_use_cols_vector, left_use_rows_vector, right_use_rows_vector] = splitData(X, use_rows_vector, use_cols_vector, used_col, threshold, type)
+      use_rows_queue = dataQueuePush(left_use_rows_vector, right_use_rows_vector, use_rows_queue)
+      use_cols_queue = dataQueuePush(new_use_cols_vector, new_use_cols_vector, use_cols_queue)
+
+      impurity_queue = dataQueuePush(matrix(impurity, rows = 1, cols = 1), matrix(impurity, rows = 1, cols = 1), impurity_queue)
+ offset = dataQueueLength(node_queue) - 1
+ M = outputMatrixBind(M, node, offset, used_col, R, threshold)
+ } else {
+      M = outputMatrixBind(M, node, 0.0, used_col, R, matrix(0, rows = 1, cols = 1))
+    }
+    queue_length = dataQueueLength(node_queue) # -- user-defined function calls not supported in relational expressions
+
+ if (verbose) {
+ print("New QueueLen: " + queue_length)
+ print("")
+ }
+ }
+}
+
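+# Queues are encoded as matrices: each column is one queue element (pushed at the end,
+# popped from the front), and the row/column "vectors" are 0/1 masks over the rows and
+# columns of X that are still in play for a node.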
+dataQueueLength = function(Matrix[Double] queue) return (Double len) {
+ len = ncol(queue)
+}
+
+dataQueuePop = function(Matrix[Double] queue) return (Matrix[Double] new_queue, Matrix[Double] node) {
+  node = matrix(queue[,1], rows=1, cols=nrow(queue)) # reshape to force the creation of a new object
+  node = matrix(node, rows=nrow(queue), cols=1) # reshape to force the creation of a new object
+ len = dataQueueLength(queue)
+ if (len < 2) {
+ new_queue = matrix(0,0,0)
+ } else {
+    new_queue = matrix(queue[,2:ncol(queue)], rows=nrow(queue), cols=ncol(queue)-1)
+ }
+}
+
+dataQueuePush = function(Matrix[Double] left, Matrix[Double] right, Matrix[Double] queue) return (Matrix[Double] new_queue) {
+ len = dataQueueLength(queue)
+ if(len <= 0) {
+ new_queue = cbind(left, right)
+ } else {
+ new_queue = cbind(queue, left, right)
+ }
+}
+
+dataVectorLength = function(Matrix[Double] vector) return (Double len) {
+ len = nrow(vector)
+}
+
+dataColVectorLength = function(Matrix[Double] vector) return (Double len) {
+ len = ncol(vector)
+}
+
+dataVectorGet = function(Matrix[Double] vector, Double index) return (Double value) {
+ value = as.scalar(vector[index, 1])
+}
+
+dataVectorSet = function(Matrix[Double] vector, Double index, Double data) return (Matrix[Double] new_vector) {
+ vector[index, 1] = data
+ new_vector = vector
+}
+
+calcAvailable = function(Matrix[Double] vector) return(Double available_elements){
+ len = dataVectorLength(vector)
+ available_elements = 0.0
+ for (index in 1:len) {
+ element = dataVectorGet(vector, index)
+ if(element > 0.0) {
+ available_elements = available_elements + 1.0
+ }
+ }
+}
+
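+# Node ids follow a complete binary tree: the root has id 1 and the children of node n
+# get ids 2*n and 2*n+1, so the depth of node n is computed as log2(n) + 1.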
+calculateNodeDepth = function(Matrix[Double] node) return(Double depth) {
+ depth = log(as.scalar(node), 2) + 1
+}
+
+calculateChildNodes = function(Matrix[Double] node) return(Matrix[Double] left, Matrix[Double] right) {
+ left = node * 2.0
+ right = node * 2.0 + 1.0
+}
+
+getTypeOfCol = function(Matrix[Double] R, Double col) return(Double type) { # 1..scale, >1..number of categories
+ type = as.scalar(R[1, col])
+}
+
+extrapolateOrderedScalarFeatures = function(
+ Matrix[Double] X,
+ Matrix[Double] use_rows_vector,
+ Double col) return (Matrix[Double] feature_vector) {
+ feature_vector = matrix(1, rows = 1, cols = 1)
+ len = nrow(X)
+ first_time = TRUE
+ for(row in 1:len) {
+ use_feature = dataVectorGet(use_rows_vector, row)
+ if (use_feature != 0) {
+ if(first_time) {
+ feature_vector[1,1] = X[row, col]
+ first_time = FALSE
+ } else {
+ feature_vector = rbind(feature_vector, X[row, col])
+ }
+ }
+ }
+  feature_vector = order(target=feature_vector, by=1, decreasing=FALSE, index.return=FALSE)
+}
+
+calcPossibleThresholdsScalar = function(
+ Matrix[Double] X,
+ Matrix[Double] use_rows_vector,
+ Double col,
+ int bins) return (Matrix[Double] thresholds) {
+ ordered_features = extrapolateOrderedScalarFeatures(X, use_rows_vector, col)
+ ordered_features_len = dataVectorLength(ordered_features)
+ thresholds = matrix(1, rows = 1, cols = ordered_features_len - 1)
+ virtual_length = min(ordered_features_len, 20)
+ step_length = ordered_features_len / virtual_length
+ if (ordered_features_len > 1) {
+ for (index in 1:(virtual_length - 1)) {
+ real_index = index * step_length
+      mean = (dataVectorGet(ordered_features, real_index) + dataVectorGet(ordered_features, real_index + 1)) / 2
+ thresholds[1, index] = mean
+ }
+ }
+}
+
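+# Enumerates the candidate category subsets for a categorical feature with 'type'
+# categories: each column of the returned matrix is one subset, holding the category
+# values it contains and -1 in the remaining rows.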
+calcPossibleThresholdsCategory = function(Double type) return (Matrix[Double] thresholds) {
+ numberThresholds = 2 ^ type
+ thresholds = matrix(-1, rows = type, cols = numberThresholds)
+ toggleFactor = numberThresholds / 2
+
+ for (index in 1:type) {
+ beginCols = 1
+ endCols = toggleFactor
+ iterations = numberThresholds / toggleFactor / 2
+ for (it in 1:iterations) {
+ category_val = type - index + 1
+      thresholds[index, beginCols:endCols] = matrix(category_val, rows = 1, cols = toggleFactor)
+ endCols = endCols + 2 * toggleFactor
+ beginCols = beginCols + 2 * toggleFactor
+ }
+ toggleFactor = toggleFactor / 2
+ iterations = numberThresholds / toggleFactor / 2
+ }
+ ncol = ncol(thresholds)
+ if (ncol > 2.0) {
+ thresholds = cbind(thresholds[,2:ncol-2], thresholds[,ncol-1])
+ }
+}
+
+calcGiniImpurity = function(Double num_true, Double num_false) return (Double impurity) {
+ prop_true = num_true / (num_true + num_false)
+ prop_false = num_false / (num_true + num_false)
+ impurity = 1 - (prop_true ^ 2) - (prop_false ^ 2)
+}
+
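+# Evaluates all candidate thresholds of one feature and keeps the one with the lowest
+# split impurity, i.e. the size-weighted average of the Gini impurities of both branches.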
+calcImpurity = function(
+ Matrix[Double] X,
+ Matrix[Double] Y,
+ Matrix[Double] use_rows_vector,
+ Double col,
+ Double type,
+ int bins) return (Double impurity, Matrix[Double] threshold) {
+
+ is_scalar_type = typeIsScalar(type)
+ if (is_scalar_type) {
+    possible_thresholds = calcPossibleThresholdsScalar(X, use_rows_vector, col, bins)
+ } else {
+ possible_thresholds = calcPossibleThresholdsCategory(type)
+ }
+ len_thresholds = ncol(possible_thresholds)
+ impurity = 1
+ threshold = matrix(0, rows=1, cols=1)
+ for (index in 1:len_thresholds) {
+    [false_rows, true_rows] = splitRowsVector(X, use_rows_vector, col, possible_thresholds[, index], type)
+    num_true_positive = 0; num_false_positive = 0; num_true_negative = 0; num_false_negative = 0
+ len = dataVectorLength(use_rows_vector)
+ for (c_row in 1:len) {
+ true_row_data = dataVectorGet(true_rows, c_row)
+ false_row_data = dataVectorGet(false_rows, c_row)
+ if (true_row_data != 0 & false_row_data == 0) { # IT'S POSITIVE!
+ if (as.scalar(Y[c_row, 1]) != 0) {
+ num_true_positive = num_true_positive + 1
+ } else {
+ num_false_positive = num_false_positive + 1
+ }
+ } else if (true_row_data == 0 & false_row_data != 0) { # IT'S NEGATIVE
+ if (as.scalar(Y[c_row, 1]) != 0.0) {
+ num_false_negative = num_false_negative + 1
+ } else {
+ num_true_negative = num_true_negative + 1
+ }
+ }
+ }
+    impurity_positive_branch = calcGiniImpurity(num_true_positive, num_false_positive)
+    impurity_negative_branch = calcGiniImpurity(num_true_negative, num_false_negative)
+    num_samples = num_true_positive + num_false_positive + num_true_negative + num_false_negative
+    num_negative = num_true_negative + num_false_negative
+    num_positive = num_true_positive + num_false_positive
+    c_impurity = num_positive / num_samples * impurity_positive_branch + num_negative / num_samples * impurity_negative_branch
+ if (c_impurity <= impurity) {
+ impurity = c_impurity
+ threshold = possible_thresholds[, index]
+ }
+ }
+}
+
+calcBestSplittingCriteria = function(
+ Matrix[Double] X,
+ Matrix[Double] Y,
+ Matrix[Double] R,
+ Matrix[Double] use_rows_vector,
+ Matrix[Double] use_cols_vector,
+  int bins) return (Double impurity, Double used_col, Matrix[Double] threshold, Double type) {
+
+ impurity = 1
+ used_col = 1
+ threshold = matrix(0, 1, 1)
+ type = 1
+ # -- user-defined function calls not supported for iterable predicates
+ len = dataVectorLength(use_cols_vector)
+ for (c_col in 1:len) {
+ use_feature = dataVectorGet(use_cols_vector, c_col)
+ if (use_feature != 0) {
+ c_type = getTypeOfCol(R, c_col)
+        [c_impurity, c_threshold] = calcImpurity(X, Y, use_rows_vector, c_col, c_type, bins)
+ if(c_impurity <= impurity) {
+ impurity = c_impurity
+ used_col = c_col
+ threshold = c_threshold
+ type = c_type
+ }
+ }
+ }
+}
+
+typeIsScalar = function(Double type) return(Boolean b) {
+ b = type == 1.0
+}
+
+splitRowsVector = function(
+ Matrix[Double] X,
+ Matrix[Double] use_rows_vector,
+ Double col,
+ Matrix[Double] threshold,
+ Double type
+) return (Matrix[Double] false_use_rows_vector, Matrix[Double] true_use_rows_vector) {
+ type_is_scalar = typeIsScalar(type)
+ false_use_rows_vector = use_rows_vector
+ true_use_rows_vector = use_rows_vector
+
+ if (type_is_scalar) {
+ scalar_threshold = as.scalar(threshold[1,1])
+ len = dataVectorLength(use_rows_vector)
+ for (c_row in 1:len) {
+ row_enabled = dataVectorGet(use_rows_vector, c_row)
+ if (row_enabled != 0) {
+ if (as.scalar(X[c_row, col]) > scalar_threshold) {
+          false_use_rows_vector = dataVectorSet(false_use_rows_vector, c_row, 0.0)
+ } else {
+ true_use_rows_vector = dataVectorSet(true_use_rows_vector, c_row, 0.0)
+ }
+ }
+ }
+ } else {
+ len = dataVectorLength(use_rows_vector)
+ for (c_row in 1:len) {
+ row_enabled = dataVectorGet(use_rows_vector, c_row)
+ if (row_enabled != 0) {
+ categories_len = dataColVectorLength(threshold)
+ move_sample_to_true_set = FALSE
+ for (category_col_index in 1:categories_len) {
+ desired_category = as.scalar(X[c_row, col])
+ if(desired_category != -1) {
+            category_of_threshold = threshold[type - desired_category + 1, category_col_index]
+            move_sample_to_true_set = as.scalar(X[c_row, col]) == as.scalar(category_of_threshold)
+ } else {
+            # TODO: does category -1 need to be considered?
+ move_sample_to_true_set = TRUE
+ }
+ }
+ if (move_sample_to_true_set) {
+          false_use_rows_vector = dataVectorSet(false_use_rows_vector, c_row, 0.0)
+        } else {
+          true_use_rows_vector = dataVectorSet(true_use_rows_vector, c_row, 0.0)
+ }
+ }
+ }
+ }
+}
+
+splitData = function(
+ Matrix[Double] X,
+ Matrix[Double] use_rows_vector,
+ Matrix[Double] use_cols_vector,
+ Double col,
+ Matrix[Double] threshold,
+ Double type
+) return (Matrix[Double] new_use_cols_vector, Matrix[Double] false_use_rows_vector, Matrix[Double] true_use_rows_vector) {
+ new_use_cols_vector = dataVectorSet(use_cols_vector, col, 0.0)
+  [false_use_rows_vector, true_use_rows_vector] = splitRowsVector(X, use_rows_vector, col, threshold, type)
+}
+
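+# Appends one column per node to the output matrix M: node id, offset to the left child,
+# chosen feature, feature type, number of threshold rows, and the threshold values;
+# shorter columns are padded with -1 so all columns have equal length.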
+outputMatrixBind = function(
+ Matrix[Double] M,
+ Matrix[Double] node,
+ Double offset,
+ Double used_col,
+ Matrix[Double] R,
+ Matrix[Double] threshold
+) return (Matrix[Double] new_M) {
+ col = matrix(0, rows = 5, cols = 1)
+ col[1, 1] = node[1, 1]
+ col[2, 1] = offset
+ col[3, 1] = used_col
+ if (used_col >= 1.0) { col[4, 1] = R[1, used_col] }
+ col[5, 1] = nrow(threshold)
+ col = rbind(col, threshold)
+
+ if (ncol(M) == 0 & nrow(M) == 0) {
+ new_M = col
+ } else {
+ row_difference = nrow(M) - nrow(col)
+ if (row_difference < 0.0) {
+ buffer = matrix(-1, rows = -row_difference, cols = ncol(M))
+ M = rbind(M, buffer)
+ } else if (row_difference > 0.0) {
+ buffer = matrix(-1, rows = row_difference, cols = 1)
+ col = rbind(col, buffer)
+ }
+ new_M = cbind(M, col)
+ }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 9080136..af357b8 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -138,6 +138,7 @@ public enum Builtins {
KMEANSPREDICT("kmeansPredict", true),
KNNBF("knnbf", true),
KNN("knn", true),
+ DECISIONTREE("decisionTree", true),
L2SVM("l2svm", true),
LASSO("lasso", true),
LENGTH("length", false),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDecisionTreeTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDecisionTreeTest.java
new file mode 100644
index 0000000..6520f1b
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDecisionTreeTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class BuiltinDecisionTreeTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "decisionTree";
+ private final static String TEST_DIR = "functions/builtin/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinDecisionTreeTest.class.getSimpleName() + "/";
+
+ private final static double eps = 1e-10;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
+ }
+
+ @Test
+ public void testDecisionTreeDefaultCP() {
+ runDecisionTree(true, LopProperties.ExecType.CP);
+ }
+
+ @Test
+ public void testDecisionTreeSP() {
+ runDecisionTree(true, LopProperties.ExecType.SPARK);
+ }
+
+ private void runDecisionTree(boolean defaultProb, LopProperties.ExecType instType) {
+ Types.ExecMode platformOld = setExecMode(instType);
+ try {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-args", input("X"), input("Y"), input("R"), output("M")};
+
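+ // Toy data set: 5 samples with 5 features; R marks the third feature as categorical
+ // with 3 categories and the remaining features as scale.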
+ double[][] Y = {{1.0}, {0.0}, {0.0}, {1.0}, {0.0}};
+
+ double[][] X = {{4.5, 4.0, 3.0, 2.8, 3.5}, {1.9, 2.4, 1.0, 3.4, 2.9}, {2.0, 1.1, 1.0, 4.9, 3.4},
+ {2.3, 5.0, 2.0, 1.4, 1.8}, {2.1, 1.1, 3.0, 1.0, 1.9},};
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("Y", Y, true);
+
+ double[][] R = {{1.0, 1.0, 3.0, 1.0, 1.0},};
+ writeInputMatrixWithMTD("R", R, true);
+
+ runTest(true, false, null, -1);
+
+ HashMap<MatrixValue.CellIndex, Double> actual_M = readDMLMatrixFromOutputDir("M");
+ HashMap<MatrixValue.CellIndex, Double> expected_M = new HashMap<MatrixValue.CellIndex, Double>();
+
+ expected_M.put(new MatrixValue.CellIndex(1, 1), 1.0);
+ expected_M.put(new MatrixValue.CellIndex(1, 3), 3.0);
+ expected_M.put(new MatrixValue.CellIndex(3, 1), 2.0);
+ expected_M.put(new MatrixValue.CellIndex(1, 2), 2.0);
+ expected_M.put(new MatrixValue.CellIndex(2, 1), 1.0);
+ expected_M.put(new MatrixValue.CellIndex(5, 1), 1.0);
+ expected_M.put(new MatrixValue.CellIndex(4, 1), 1.0);
+ expected_M.put(new MatrixValue.CellIndex(5, 3), 1.0);
+ expected_M.put(new MatrixValue.CellIndex(5, 2), 1.0);
+ expected_M.put(new MatrixValue.CellIndex(6, 1), 3.2);
+
+ TestUtils.compareMatrices(expected_M, actual_M, eps, "Expected-DML", "Actual-DML");
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLmTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLmTest.java
index df87f61..65b5372 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLmTest.java
@@ -47,7 +47,7 @@ public class BuiltinLmTest extends AutomatedTestBase
@Override
public void setUp() {
- addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
}
@Test
diff --git a/src/test/scripts/functions/builtin/decisionTree.dml b/src/test/scripts/functions/builtin/decisionTree.dml
new file mode 100644
index 0000000..31bd820
--- /dev/null
+++ b/src/test/scripts/functions/builtin/decisionTree.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+Y = read($2);
+R = read($3);
+M = decisionTree(X = X, Y = Y, R = R);
+write(M, $4);