This is an automated email from the ASF dual-hosted git repository.

ssiddiqi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 97e14f8  [SYSTEMDS-3209] Builtin for random under-sampling - the builtin accepts matrix data with the last column as labels and a ratio parameter, and randomly removes tuples from the majority class
97e14f8 is described below

commit 97e14f81bbfa440986e670216432afe67cc7c051
Author: Shafaq Siddiqi <[email protected]>
AuthorDate: Mon Nov 8 17:14:01 2021 +0100

    [SYSTEMDS-3209] Builtin for random under-sampling
      - the builtin accepts matrix data with the last column as labels and a ratio
        parameter, and randomly removes tuples from the majority class
---
 scripts/builtin/underSampling.dml                  | 49 ++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 .../builtin/BuiltinUnderSamplingTest.java          | 79 ++++++++++++++++++++++
 .../functions/builtin/underSamplingTest.dml        | 36 ++++++++++
 4 files changed, 165 insertions(+)

diff --git a/scripts/builtin/underSampling.dml b/scripts/builtin/underSampling.dml
new file mode 100644
index 0000000..debc367
--- /dev/null
+++ b/scripts/builtin/underSampling.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# This built-in performs random under-sampling on labeled data.
+
+# Randomly under-samples all rows that do not belong to the minority class.
+#
+# INPUT:
+#   data  - matrix whose last column holds the class labels; the labels are
+#           used as indices by table(), so presumably they must be positive
+#           integers (TODO confirm with callers)
+#   ratio - fraction of the non-minority rows to remove, expected in [0, 0.5];
+#           values outside this range are replaced by 0.1
+#
+# OUTPUT:
+#   data  - the input rows without floor(ratio * #non-minority rows) randomly
+#           chosen rows that are not in the minority class; with more than two
+#           classes, every non-minority class is a candidate for removal
+underSampling = function(Matrix[Double] data, Double ratio)
+return(Matrix[Double] data)
+{
+  if(ratio < 0 | ratio > 0.5) {
+    ratio = 0.1
+    print("ratio should be between 0 and 0.5 (inclusive), setting ratio = 0.1")
+  }
+  # separate the labels
+  Y = data[, ncol(data)]
+  # per-label counts; labels that do not occur in Y get a count of 0
+  classes = table(Y, 1)
+  # push absent labels above any real count so the minority is an existing class
+  counts = classes + (classes == 0) * (nrow(Y) + 1)
+  minority = as.scalar(rowIndexMin(t(counts)))
+  # every row outside the minority class is a candidate for removal
+  notMin = (Y != minority)
+  # prepend row ids so the chosen candidates can be mapped back onto data
+  dX = cbind(seq(1, nrow(data)), data)
+  majority = removeEmpty(target=dX, margin="rows", select=notMin)
+  # number of candidate rows to remove (sum(notMin) is exact even if no candidates exist)
+  numRemove = floor(sum(notMin) * ratio)
+  if(numRemove > 0) {
+    # draw distinct candidate positions for undersampling
+    u_sample = sample(nrow(majority), numRemove)
+    # indicator over the candidates, scaled by their original row ids
+    u_select = table(u_sample, 1, 1, nrow(majority), 1) * majority[, 1]
+    u_select = removeEmpty(target=u_select, margin="rows")
+    # indicator over the rows of data that have to be dropped
+    removeMask = table(u_select, 1, 1, nrow(data), 1)
+    data = removeEmpty(target=data, margin="rows", select=(removeMask == 0))
+  }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 6b94b55..cfebfbc 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -309,6 +309,7 @@ public enum Builtins {
        TRANSFORMDECODE("transformdecode", false, true),
        TRANSFORMENCODE("transformencode", false, true),
        TRANSFORMMETA("transformmeta", false, true),
+       UNDER_SAMPLING("underSampling", true),
        UPPER_TRI("upper.tri", false, true),
        XDUMMY1("xdummy1", true), //error handling test
        XDUMMY2("xdummy2", true); //error handling test
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUnderSamplingTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUnderSamplingTest.java
new file mode 100644
index 0000000..38cf610
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUnderSamplingTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests the {@code underSampling} builtin in CP and Spark execution modes by running
+ * {@code underSamplingTest.dml} with different under-sampling ratios.
+ */
+public class BuiltinUnderSamplingTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "underSamplingTest";
+       private final static String TEST_DIR = "functions/builtin/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinUnderSamplingTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+       }
+
+       @Test
+       public void test_CP1() {
+               runUnderSamplingTest(0.3, Types.ExecType.CP);
+       }
+
+       @Test
+       public void test_CP2() {
+               runUnderSamplingTest(0.5, Types.ExecType.CP);
+       }
+
+       @Test
+       public void test_Spark() {
+               runUnderSamplingTest(0.4, Types.ExecType.SPARK);
+       }
+
+       /**
+        * Runs the DML test script with the given ratio and asserts that it printed TRUE.
+        *
+        * @param ratio    under-sampling ratio passed to the script as {@code $1}
+        * @param instType execution type to run the script with
+        */
+       private void runUnderSamplingTest(double ratio, Types.ExecType instType) {
+               Types.ExecMode platformOld = setExecMode(instType);
+               try {
+                       setOutputBuffering(true);
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-args", String.valueOf(ratio)};
+
+                       String out = runTest(null).toString();
+                       Assert.assertTrue("underSampling check failed for ratio=" + ratio + ", output: " + out,
+                               out.contains("TRUE"));
+               }
+               finally {
+                       // restore via resetExecMode rather than assigning rtplatform directly, so that any
+                       // extra state set by setExecMode (e.g. the local Spark config flag) is undone too
+                       resetExecMode(platformOld);
+               }
+       }
+}
diff --git a/src/test/scripts/functions/builtin/underSamplingTest.dml b/src/test/scripts/functions/builtin/underSamplingTest.dml
new file mode 100644
index 0000000..cdc17dc
--- /dev/null
+++ b/src/test/scripts/functions/builtin/underSamplingTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Checks underSampling on an imbalanced two-class data set and prints TRUE on success.
+ratio = as.double($1)
+X = rand(rows=20, cols=4, min=1, max=100)
+Y = rbind(matrix(1, rows=15, cols=1), matrix(2, rows=5, cols=1))
+data = cbind(X, Y)
+classesUnBalanced = table(Y, 1)
+# shuffle the rows so that the two classes are interleaved
+IX = sample(nrow(data), nrow(data))
+P = table(seq(1, nrow(IX)), IX, nrow(IX), nrow(data))
+data = P %*% data
+balanced = underSampling(data, ratio)
+classesBalanced = table(balanced[, ncol(balanced)], 1)
+# the majority class (label 1) must lose exactly floor(#majority * ratio) rows ...
+majRemoved = as.scalar(classesUnBalanced[1, 1] - classesBalanced[1, 1])
+expected = floor(as.scalar(classesUnBalanced[1, 1]) * ratio)
+# ... while the minority class (label 2) must be left untouched
+minKept = as.scalar(classesBalanced[2, 1]) == as.scalar(classesUnBalanced[2, 1])
+out = (majRemoved == expected) & minKept
+print(out)
+
+

Reply via email to