This is an automated email from the ASF dual-hosted git repository.

ssiddiqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new f4cc41c  [SYSTEMDS-2658] Synthetic Minority Over-sampling Technique (SMOTE) for handling imbalanced classes by oversampling the minority class
f4cc41c is described below

commit f4cc41c8c90dc9787573af36db69cbd3b66a6bd2
Author: Shafaq Siddiqi <[email protected]>
AuthorDate: Wed Sep 2 14:20:50 2020 +0200

    [SYSTEMDS-2658] Synthetic Minority Over-sampling Technique (SMOTE)
    for handling imbalanced classes by oversampling the minority class
    
    Date:      Wed Sep 2 14:16:32 2020 +0200
    Closes #988
---
 scripts/builtin/smote.dml                          |  99 ++++++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../test/functions/builtin/BuiltinSmoteTest.java   | 103 +++++++++++++++++++++
 src/test/scripts/functions/builtin/smote.dml       |  40 ++++++++
 4 files changed, 243 insertions(+)

diff --git a/scripts/builtin/smote.dml b/scripts/builtin/smote.dml
new file mode 100644
index 0000000..14947ea
--- /dev/null
+++ b/scripts/builtin/smote.dml
@@ -0,0 +1,99 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Builtin function for handling class imbalance using Synthetic Minority Over-sampling Technique (SMOTE)
+#
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME            TYPE    DEFAULT     MEANING
+# ---------------------------------------------------------------------------------------------
+# X               Double   ---       Matrix of minority class samples
+# s               Integer   200      Amount of SMOTE (percentage of oversampling), integral multiple of 100
+# k               Integer   1        Number of nearest neighbours
+# verbose         Boolean   FALSE    Verbosity flag (currently unused)
+# ---------------------------------------------------------------------------------------------
+
+
+#Output(s)
+# ---------------------------------------------------------------------------------------------
+# NAME            TYPE    DEFAULT     MEANING
+# ---------------------------------------------------------------------------------------------
+# Y               Double   ---       Matrix of ((s/100)-1) * nrow(X) synthetic minority class samples
+
+m_smote = function(Matrix[Double] X, Integer s = 200, Integer k = 1, Boolean verbose = FALSE) 
+return (Matrix[Double] Y) {
+  # Generates ((s/100)-1) * nrow(X) synthetic minority samples. Each one lies
+  # on the segment between a minority sample and one of its k nearest
+  # neighbours (found with nn below).
+  # NOTE: verbose is accepted for interface compatibility but currently unused.
+
+  # s must be a multiple of 100 that is at least 100; otherwise fall back to
+  # 100, which yields no synthetic samples
+  if(s < 100 | (s%%100) != 0)
+  {
+    print("the number of samples should be an integral multiple of 100. Setting s = 100")
+    s = 100
+  }
+
+  n = nrow(X)
+  # column i holds the row indices (into X) of the k nearest neighbours of
+  # row i; preallocated instead of grown with cbind to avoid quadratic copying
+  knn_index = matrix(0, k, n)
+  for(i in 1:n)
+  {
+    knn_index[, i] = nn(X, X[i, ], k)
+  }
+
+  # number of synthetic samples generated from each minority class sample
+  iter = (s/100)-1
+  # matrix to store synthetic samples
+  synthetic_samples = matrix(0, 0, ncol(X))
+  while(iter > 0)
+  {
+    # pick one of the k nearest neighbours uniformly at random; sample(k, 1)
+    # draws from 1..k, whereas truncating a uniform value in [1, k] almost
+    # never yields k, so the k-th neighbour was effectively never used
+    # TODO avoid duplicate random numbers across iterations
+    rand_index = as.scalar(sample(k, 1))
+    # knn_sample[1, i] is the chosen neighbour (row index) of minority sample i
+    knn_sample = knn_index[rand_index, ]
+    # selection matrix with P[i, knn_sample[1, i]] = 1, so P %*% X gathers the
+    # chosen neighbour of every sample in one operation (rows stay in order)
+    P = table(seq(1, n), t(knn_sample), n, n)
+    X_nn = P %*% X
+    # one random gap in [0, 1] per sample, broadcast across all features
+    gap = rand(rows=n, cols=1, min=0, max=1)
+    X_syn = X + gap * (X_nn - X)
+    synthetic_samples = rbind(synthetic_samples, X_syn)
+    iter = iter - 1
+  }
+
+  Y = synthetic_samples
+}
+  
+
+
+nn = function(Matrix[Double] X, Matrix[Double] instance, Integer k )
+return (Matrix[Double] knn_)
+{
+  # Returns a k x 1 column with the row indices of the k rows of X closest
+  # (Euclidean distance) to `instance`, excluding the closest row itself.
+  # NOTE(review): assumes `instance` is a row of X, so the closest row
+  # (distance 0) is the instance itself; with duplicate rows a tie could
+  # place the instance after its duplicate.
+
+  # the instance occupies position 1 of the ordering, so k + 1 rows are
+  # needed to return k neighbours (nrow(X) == k used to index out of bounds)
+  if(nrow(X) < k + 1)
+  {
+    stop("can not pick "+k+" nearest neighbours from "+nrow(X)+" total instances")
+  }
+
+  # compute the euclidean distance of every row of X to the instance
+  diff = X - instance
+  square_diff = diff^2
+  distance = sqrt(rowSums(square_diff))
+  sort_dist = order(target = distance, by = 1, decreasing= FALSE, index.return = TRUE)
+  # skip position 1 (the instance itself) and keep the next k indices
+  knn_ = sort_dist[2:(k+1),]
+}
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 20fad72..28acc5b 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -177,6 +177,7 @@ public enum Builtins {
        SINH("sinh", false),
        STEPLM("steplm",true, ReturnType.MULTI_RETURN),
        SLICEFINDER("slicefinder", true),
+       SMOTE("smote", true),
        SOLVE("solve", false),
        SQRT("sqrt", false),
        SUM("sum", false),
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSmoteTest.java 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSmoteTest.java
new file mode 100644
index 0000000..675e741
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSmoteTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class BuiltinSmoteTest extends AutomatedTestBase {
+
+       private final static String TEST_NAME = "smote";
+       private final static String TEST_DIR = "functions/builtin/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinSmoteTest.class.getSimpleName() + "/";
+
+       private final static int rows = 20;
+       private final static int colsX = 20;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
+       }
+
+       @Test
+       public void testSmote1CP() {
+               runSmoteTest(300, 3, LopProperties.ExecType.CP);
+       }
+
+       @Test
+       public void testSmote2CP() {
+               runSmoteTest(400, 5, LopProperties.ExecType.CP);
+       }
+
+       @Test
+       public void testSmote1Spark() {
+               runSmoteTest(300, 3, LopProperties.ExecType.SPARK);
+       }
+
+       @Test
+       public void testSmote2Spark() {
+               runSmoteTest(400, 5, LopProperties.ExecType.SPARK);
+       }
+
+       /**
+        * Runs the smote test script with the given oversampling amount and number of
+        * neighbours. Expects the script to write 1 to "Sum" when the original and
+        * synthetic minority samples all fall into the same kmeans cluster.
+        *
+        * @param sample   amount of SMOTE (percentage, multiple of 100)
+        * @param nn       number of nearest neighbours
+        * @param instType execution type (CP or SPARK)
+        */
+       private void runSmoteTest(int sample, int nn, LopProperties.ExecType instType) {
+               Types.ExecMode platformOld = setExecMode(instType);
+
+               // save exactly the global state this test modifies so it can be restored
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-nvargs", "X=" + input("X"), "S=" + sample, "K=" + nn , "Z="+output("Sum"), "T="+input("T")};
+
+                       double[][] X = getRandomMatrix(rows, colsX, 0, 1, 0.3, 1);
+                       writeInputMatrixWithMTD("X", X, true);
+
+                       double[][] T = getRandomMatrix(rows, colsX, 2, 3.0, 0.3, 3);
+                       writeInputMatrixWithMTD("T", T, true);
+
+                       runTest(true, false, null, -1);
+                       HashMap<MatrixValue.CellIndex, Double> value = readDMLMatrixFromHDFS("Sum");
+                       Double result = value.get(new MatrixValue.CellIndex(1,1));
+                       // guard before unboxing so a missing cell fails the assertion instead of throwing an NPE
+                       Assert.assertNotNull("no result value written by the test script", result);
+                       Assert.assertEquals("synthetic samples does not fall into minority class cluster",1,
+                               result, 0.000001);
+               }
+               finally {
+                       // restore only what was changed above; flags this test never touches are left alone
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
+
diff --git a/src/test/scripts/functions/builtin/smote.dml 
b/src/test/scripts/functions/builtin/smote.dml
new file mode 100644
index 0000000..dc33f18
--- /dev/null
+++ b/src/test/scripts/functions/builtin/smote.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = read($X);
+B = smote(X = A, s=$S, k=$K);
+
+# test if all points fall into the same cluster (close to each other)
+# read some new data T != A
+T = read($T);
+# bind all instances of the minority class (original and synthetic)
+A_B = rbind(A, B)
+n = nrow(A_B)
+# group data into k=2 clusters
+[C, Y] = kmeans(rbind(A_B, T),  2, 10, 100, 0.000001, FALSE, 50)
+# count minority instances whose cluster label differs from the first one;
+# comparing labels with != avoids the old label-difference sum, which equalled
+# 1 (the success value) for a single mismatch when the first label was 2
+mismatches = sum(Y[1:n,] != as.scalar(Y[1,1]))
+# write mismatches + 1 so success is 1 instead of 0: a single zero cell is not
+# materialized and would read back as null in the Java test
+testSum = as.matrix(mismatches + 1)
+write(testSum, $Z);
\ No newline at end of file

Reply via email to