This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 6a318ecee6 [MINOR] Additional real-data decision tree tests (existing 
datasets)
6a318ecee6 is described below

commit 6a318ecee624ab448951f53bb9f4443bb510c4ad
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Jun 11 21:28:11 2023 +0200

    [MINOR] Additional real-data decision tree tests (existing datasets)
---
 .../part1/BuiltinDecisionTreeRealDataTest.java     | 18 +++++++-
 src/test/resources/datasets/EEG_tfspec.json        | 19 +++++++++
 .../functions/builtin/decisionTreeRealData4.dml    | 49 ++++++++++++++++++++++
 3 files changed, 84 insertions(+), 2 deletions(-)

diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
index 41bd6c651a..58ad08b625 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
@@ -37,10 +37,12 @@ public class BuiltinDecisionTreeRealDataTest extends 
AutomatedTestBase {
        private final static String TITANIC_TFSPEC = DATASET_DIR + 
"titanic/tfspec.json";
        private final static String WINE_DATA = DATASET_DIR + 
"wine/winequality-red-white.csv";
        private final static String WINE_TFSPEC = DATASET_DIR + 
"wine/tfspec.json";
+       private final static String EEG_DATA = DATASET_DIR + "EEG.csv";
+       private final static String EEG_TFSPEC = DATASET_DIR + 
"EEG_tfspec.json";
 
        @Override
        public void setUp() {
-               for(int i=1; i<=3; i++)
+               for(int i=1; i<=4; i++)
                        addTestConfiguration(TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
        }
 
@@ -99,12 +101,24 @@ public class BuiltinDecisionTreeRealDataTest extends 
AutomatedTestBase {
                //for regression we compare R2 and use rss to optimize
                runDecisionTree(3, WINE_DATA, WINE_TFSPEC, 0.369, 2, 1.0, 
ExecType.CP);
        }
+       
+       @Test
+       public void testDecisionTreeEEG_MaxV1() {
+               //for classification we compare accuracy against the minimum threshold
+               runDecisionTree(4, EEG_DATA, EEG_TFSPEC, 0.62, 1, 1.0, 
ExecType.CP);
+       }
+       
+       @Test
+       public void testRandomForestEEG_MaxV1() {
+               //for classification we compare accuracy against the minimum threshold
+               runDecisionTree(4, EEG_DATA, EEG_TFSPEC, 0.62, 2, 1.0, 
ExecType.CP);
+       }
 
        private void runDecisionTree(int test, String data, String tfspec, 
double minAcc, int dt, double maxV, ExecType instType) {
                Types.ExecMode platformOld = setExecMode(instType);
                try {
                        
loadTestConfiguration(getTestConfiguration(TEST_NAME+test));
-
+                       
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + (TEST_NAME+test) + ".dml";
                        programArgs = new String[] {"-stats",
diff --git a/src/test/resources/datasets/EEG_tfspec.json 
b/src/test/resources/datasets/EEG_tfspec.json
new file mode 100644
index 0000000000..361d0a85a1
--- /dev/null
+++ b/src/test/resources/datasets/EEG_tfspec.json
@@ -0,0 +1,19 @@
+{
+  "ids":true,
+  "recode":[1],
+  "bin":[
+    {"id":2, "method":"equi-width", "numbins":100},
+    {"id":3, "method":"equi-width", "numbins":100},
+    {"id":4, "method":"equi-width", "numbins":100},
+    {"id":5, "method":"equi-width", "numbins":100},
+    {"id":6, "method":"equi-width", "numbins":100},
+    {"id":7, "method":"equi-width", "numbins":100},
+    {"id":8, "method":"equi-width", "numbins":100},
+    {"id":9, "method":"equi-width", "numbins":100},
+    {"id":10, "method":"equi-width", "numbins":100},
+    {"id":11, "method":"equi-width", "numbins":100},
+    {"id":12, "method":"equi-width", "numbins":100},
+    {"id":13, "method":"equi-width", "numbins":100},
+    {"id":14, "method":"equi-width", "numbins":100},
+    {"id":15, "method":"equi-width", "numbins":100}]
+}
diff --git a/src/test/scripts/functions/builtin/decisionTreeRealData4.dml 
b/src/test/scripts/functions/builtin/decisionTreeRealData4.dml
new file mode 100644
index 0000000000..e1429f6499
--- /dev/null
+++ b/src/test/scripts/functions/builtin/decisionTreeRealData4.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+F = read($1, data_type="frame", format="csv", header=TRUE);
+tfspec = read($2, data_type="scalar", value_type="string");
+
+R = matrix("1 1 1 1 1 1 1 1 1 1 1 1 1 1 2", rows=1, cols=15)
+
+[X, meta] = transformencode(target=F, spec=tfspec);
+Y = X[, 1]
+X = X[, 2:ncol(X)]
+X = imputeByMode(X);
+
+if( $3==1 ) {
+  M = decisionTree(X=X, y=Y, ctypes=R, max_features=1, max_values=$4,
+                   min_split=10, min_leaf=4, seed=7, verbose=TRUE);
+  yhat = decisionTreePredict(X=X, y=Y, ctypes=R, M=M)
+}
+else {
+  sf = 1.0/($3-1);
+  M = randomForest(X=X, y=Y, ctypes=R, sample_frac=sf, num_trees=$3-1,
+                   max_features=1, max_values=$4,
+                   min_split=10, min_leaf=4, seed=7, verbose=TRUE);
+  yhat = randomForestPredict(X=X, y=Y, ctypes=R,  M=M)
+}
+
+acc = as.matrix(mean(yhat == Y))
+err = 1-(acc);
+print("accuracy: "+as.scalar(acc))
+
+write(acc, $5);

Reply via email to