This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 066c0aa2ae [SYSTEMDS-3777] Additional adasyn real data tests
066c0aa2ae is described below
commit 066c0aa2ae69b69e2bae954e032083507f879705
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Nov 18 10:17:32 2024 +0100
[SYSTEMDS-3777] Additional adasyn real data tests
This generalizes the adasyn test to additional real datasets. On the
titanic dataset, adasyn gives a 1.6 percentage-point improvement in test
accuracy (for a basic logreg model, 0.781 -> 0.797).
---
.../org/apache/sysds/parser/DMLTranslator.java | 2 +-
.../builtin/part1/BuiltinAdasynRealDataTest.java | 21 +++++++++++++++--
src/test/resources/datasets/diabetes/tfspec.json | 1 +
.../scripts/functions/builtin/adasynRealData.dml | 26 +++++++++++++++++-----
4 files changed, 41 insertions(+), 9 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index b76425668f..6121711933 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1610,7 +1610,7 @@ public class DMLTranslator
return currBuiltinOp;
}
else{
- throw new ParseException("Unhandled instance of
source type: " + source.getClass());
+ throw new ParseException("Unhandled instance of
source type: " + source);
}
}
catch(ParseException e ){
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
index e310fd877a..0646c2f5d1 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
@@ -35,6 +35,8 @@ public class BuiltinAdasynRealDataTest extends
AutomatedTestBase {
private final static String DIABETES_DATA = DATASET_DIR +
"diabetes/diabetes.csv";
private final static String DIABETES_TFSPEC = DATASET_DIR +
"diabetes/tfspec.json";
+ private final static String TITANIC_DATA = DATASET_DIR +
"titanic/titanic.csv";
+ private final static String TITANIC_TFSPEC = DATASET_DIR +
"titanic/tfspec.json";
@Override
public void setUp() {
@@ -56,6 +58,21 @@ public class BuiltinAdasynRealDataTest extends
AutomatedTestBase {
runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, true, 0.787, 6,
ExecType.CP);
}
+ @Test
+ public void testTitanicNoAdasyn() {
+ runAdasynTest(TITANIC_DATA, TITANIC_TFSPEC, false, 0.781, -1,
ExecType.CP);
+ }
+
+ @Test
+ public void testTitanicAdasynK4() {
+ runAdasynTest(TITANIC_DATA, TITANIC_TFSPEC, true, 0.797, 4,
ExecType.CP);
+ }
+
+ @Test
+ public void testTitanicAdasynK5() {
+ runAdasynTest(TITANIC_DATA, TITANIC_TFSPEC, true, 0.797, 5,
ExecType.CP);
+ }
+
private void runAdasynTest(String data, String tfspec, boolean adasyn,
double minAcc, int k, ExecType instType) {
Types.ExecMode platformOld = setExecMode(instType);
try {
@@ -63,8 +80,8 @@ public class BuiltinAdasynRealDataTest extends
AutomatedTestBase {
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats",
- "-args", data, String.valueOf(adasyn),
String.valueOf(k), output("R")};
+ programArgs = new String[] {"-stats", "-args",
+ data, tfspec, String.valueOf(adasyn),
String.valueOf(k), output("R")};
runTest(true, false, null, -1);
diff --git a/src/test/resources/datasets/diabetes/tfspec.json
b/src/test/resources/datasets/diabetes/tfspec.json
new file mode 100644
index 0000000000..8d1c8b69c3
--- /dev/null
+++ b/src/test/resources/datasets/diabetes/tfspec.json
@@ -0,0 +1 @@
+
diff --git a/src/test/scripts/functions/builtin/adasynRealData.dml
b/src/test/scripts/functions/builtin/adasynRealData.dml
index 6e401ec336..80a2d83347 100644
--- a/src/test/scripts/functions/builtin/adasynRealData.dml
+++ b/src/test/scripts/functions/builtin/adasynRealData.dml
@@ -20,16 +20,30 @@
#-------------------------------------------------------------
-M = read($1, data_type="matrix", format="csv", header=TRUE);
-Y = M[, ncol(M)] + 1
-X = M[, 1:ncol(M)-1]
-upsample = as.logical($2)
+M = read($1, data_type="frame", format="csv", header=TRUE,
+ naStrings= ["NA", "null"," ","NaN", "nan", "", " ", "_nan_", "inf",
"?", "NAN", "99999", "99999.00"]);
+Y = as.matrix(M[, ncol(M)]) + 1
+F = M[, 1:ncol(M)-1]
+tfspec = read($2, data_type="scalar", value_type="string")
+upsample = as.logical($3)
+
+if( tfspec != " " ) {
+ F = M[, 1:ncol(M)] # FIXME
+ [X,meta] = transformencode(target=F, spec=tfspec);
+ X = X[,1:ncol(X)-1];
+ X = imputeByMode(X);
+}
+else {
+ X = as.matrix(F);
+}
+
+[X,C,S] = scale(X=X, scale=TRUE, center=TRUE);
[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7, seed=3);
if( upsample ) {
# oversampling all classes other than majority
- [Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3, seed=7);
+ [Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$4, seed=7);
}
B = multiLogReg(X=Xtrain, Y=Ytrain, icpt=2);
@@ -37,5 +51,5 @@ B = multiLogReg(X=Xtrain, Y=Ytrain, icpt=2);
print("accuracy: "+acc)
R = as.matrix(acc/100);
-write(R, $4);
+write(R, $5);