This is an automated email from the ASF dual-hosted git repository.
ssiddiqi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 97e14f8 [SYSTEMDS-3209] Builtin for random under-sampling - the
builtin accepts matrix data with the last column as labels and a ratio parameter
and randomly removes tuples from the majority class
97e14f8 is described below
commit 97e14f81bbfa440986e670216432afe67cc7c051
Author: Shafaq Siddiqi <[email protected]>
AuthorDate: Mon Nov 8 17:14:01 2021 +0100
[SYSTEMDS-3209] Builtin for random under-sampling
- the builtin accepts matrix data with the last column as labels and a ratio
parameter, and randomly removes tuples from the majority class
---
scripts/builtin/underSampling.dml | 49 ++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../builtin/BuiltinUnderSamplingTest.java | 79 ++++++++++++++++++++++
.../functions/builtin/underSamplingTest.dml | 36 ++++++++++
4 files changed, 165 insertions(+)
diff --git a/scripts/builtin/underSampling.dml b/scripts/builtin/underSampling.dml
new file mode 100644
index 0000000..debc367
--- /dev/null
+++ b/scripts/builtin/underSampling.dml
@@ -0,0 +1,66 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Builtin function that performs random under-sampling: it randomly removes a
+# fraction of the rows that do not belong to the minority class.
+#
+# INPUT:
+#   data   matrix whose last column holds the class labels; labels are expected
+#          to be positive integers because they are aggregated with table()
+#   ratio  fraction of the non-minority rows to remove, expected in (0, 0.5];
+#          any other value is replaced by 0.1
+#
+# OUTPUT:
+#   data   the input without floor(ratio * #non-minority rows) randomly chosen
+#          non-minority rows; rows of the minority class are always kept
+
+underSampling = function(Matrix[Double] data, Double ratio)
+return(Matrix[Double] data)
+{
+	if(ratio <= 0 | ratio > 0.5) {
+		ratio = 0.1
+		print("ratio should be greater than 0 and at most 0.5, setting ratio = 0.1")
+	}
+	# separate the labels (last column)
+	Y = data[, ncol(data)]
+	# per-class counts; table() emits a row for every label 1..max(Y), so
+	# labels that do not occur in Y get a count of zero
+	classes = table(Y, 1)
+	# lift absent labels above any real count so they cannot become the minority
+	classes = classes + (classes == 0) * (nrow(Y) + 1)
+	minority = as.scalar(rowIndexMin(t(classes)))
+	# original row ids of all rows outside the minority class
+	notMin = (Y != minority)
+	candIds = removeEmpty(target=seq(1, nrow(data)), margin="rows", select=notMin)
+	# number of candidate rows to remove
+	numDrop = floor(nrow(candIds) * ratio)
+	# nothing to remove (e.g. a single class or too few candidates):
+	# return data unchanged instead of calling sample() with size 0
+	if(numDrop > 0) {
+		# pick numDrop candidate positions (sample() draws without replacement by default)
+		u_sample = sample(nrow(candIds), numDrop)
+		# map the sampled positions back to the original row ids
+		u_select = table(u_sample, 1, 1, nrow(candIds), 1) * candIds
+		u_select = removeEmpty(target=u_select, margin="rows")
+		# indicator over all rows of data: 1 marks a row to remove
+		dropMask = table(u_select, 1, 1, nrow(data), 1)
+		data = removeEmpty(target=data, margin="rows", select=(dropMask == 0))
+	}
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 6b94b55..cfebfbc 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -309,6 +309,7 @@ public enum Builtins {
TRANSFORMDECODE("transformdecode", false, true),
TRANSFORMENCODE("transformencode", false, true),
TRANSFORMMETA("transformmeta", false, true),
+ UNDER_SAMPLING("underSampling", true),
UPPER_TRI("upper.tri", false, true),
XDUMMY1("xdummy1", true), //error handling test
XDUMMY2("xdummy2", true); //error handling test
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUnderSamplingTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUnderSamplingTest.java
new file mode 100644
index 0000000..38cf610
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUnderSamplingTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests the {@code underSampling} builtin through the DML script
+ * {@code functions/builtin/underSamplingTest.dml}, which prints {@code TRUE}
+ * when the builtin removed the expected number of rows.
+ */
+public class BuiltinUnderSamplingTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "underSamplingTest";
+	private static final String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinUnderSamplingTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+	}
+
+	@Test
+	public void test_CP1() {
+		runUnderSamplingTest(0.3, Types.ExecType.CP);
+	}
+
+	@Test
+	public void test_CP2() {
+		runUnderSamplingTest(0.5, Types.ExecType.CP);
+	}
+
+	@Test
+	public void test_Spark() {
+		runUnderSamplingTest(0.4, Types.ExecType.SPARK);
+	}
+
+	/**
+	 * Runs the DML test script with the given ratio and asserts that it printed {@code TRUE}.
+	 *
+	 * @param ratio    under-sampling ratio, passed to the script as {@code $1}
+	 * @param instType execution type the script is run with
+	 */
+	private void runUnderSamplingTest(double ratio, Types.ExecType instType) {
+		Types.ExecMode platformOld = setExecMode(instType);
+		try {
+			setOutputBuffering(true);
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", String.valueOf(ratio)};
+			String out = runTest(null).toString();
+			Assert.assertTrue("underSampling removed an unexpected number of rows for ratio " + ratio,
+				out.contains("TRUE"));
+		}
+		finally {
+			// NOTE(review): resetExecMode is assumed to be the full counterpart of setExecMode in
+			// AutomatedTestBase (also clearing the local Spark config flag set for the Spark test),
+			// which assigning rtplatform alone would leave behind
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/underSamplingTest.dml b/src/test/scripts/functions/builtin/underSamplingTest.dml
new file mode 100644
index 0000000..cdc17dc
--- /dev/null
+++ b/src/test/scripts/functions/builtin/underSamplingTest.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test for the underSampling builtin: builds a dataset with 15 rows of class 1
+# and 5 rows of class 2, shuffles it, under-samples it with ratio $1 and prints
+# TRUE iff exactly floor(15 * ratio) class-1 rows were removed and all class-2
+# rows were kept.
+
+ratio = as.double($1)
+X = rand(rows=20, cols=4, min=1, max=100)
+Y = rbind(matrix(1, rows=15, cols=1), matrix(2, rows=5, cols=1))
+data = cbind(X, Y)
+classesUnBalanced = table(Y, 1)
+# randomly permute the rows so that the two classes are interleaved
+IX = sample(nrow(data), nrow(data))
+P = table(seq(1, nrow(IX)), IX, nrow(IX), nrow(data))
+data = P %*% data
+balanced = underSampling(data, ratio)
+classesBalanced = table(balanced[, ncol(balanced)], 1)
+# the majority class (label 1) must shrink by exactly floor(15 * ratio) rows
+majorityOk = as.scalar(classesUnBalanced[1, 1] - classesBalanced[1, 1]) == floor(15.0 * ratio)
+# the minority class (label 2) must be kept completely
+minorityOk = as.scalar(classesBalanced[2, 1]) == as.scalar(classesUnBalanced[2, 1])
+out = majorityOk & minorityOk
+print(out)