This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 86fd7b3  [SYSTEMDS-394] New builtin function toOneHot (one hot encoding)
86fd7b3 is described below

commit 86fd7b3d4aae5dbca8090e2638e0abc4da696655
Author: Patrick Deutschmann <[email protected]>
AuthorDate: Sat May 23 22:57:17 2020 +0200

    [SYSTEMDS-394] New builtin function toOneHot (one hot encoding)
    
    Adds a builtin function toOneHot which transforms a vector containing
    integers into a one-hot-encoded matrix (note: transform works over frames
    and reassigns the integer codes, whereas toOneHot uses the given integers
    directly as column indices)
    
    Closes #916.
---
 docs/Tasks.txt                                     |   1 +
 docs/dml-language-reference.md                     |   2 +-
 scripts/builtin/toOneHot.dml                       |  43 ++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../functions/builtin/BuiltinToOneHotTest.java     | 113 +++++++++++++++++++++
 src/test/scripts/functions/builtin/toOneHot.dml    |  25 +++++
 6 files changed, 184 insertions(+), 1 deletion(-)

diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 6539ff8..7b64145 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -302,6 +302,7 @@ SYSTEMDS-390 New Builtin Functions IV
  * 391 New GLM builtin function (from algorithms)                     OK
  * 392 Builtin function for missing value imputation via FDs          OK
  * 393 Builtin to find Connected Components of a graph                OK
+ * 394 Builtin for one-hot encoding of matrix (not frame), see table  OK
 
 Others:
  * Break append instruction to cbind and rbind 
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 652a451..76f2656 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -699,7 +699,7 @@ cummin() | Column prefix-min (For row-prefix min, use 
cummin(t(X)) | Input: matr
 cummax() | Column prefix-max (For row-prefix max, use cummax(t(X)) | Input: 
matrix <br/> Output: matrix of the same dimensions | A = matrix("3 4 1 6 5 2", 
rows=3, cols=2) <br/> B = cummax(A) <br/> The output matrix B = [[3, 4], [3, 
6], [5, 6]]
 sample(range, size, replacement, seed) | Sample returns a column vector of 
length size, containing uniform random numbers from [1, range] | Input: <br/> 
range: integer <br/> size: integer <br/> replacement: boolean (Optional, 
default: FALSE) <br/> seed: integer (Optional) <br/> Output: Matrix dimensions 
are size x 1 | sample(100, 5) <br/> sample(100, 5, TRUE) <br/> sample(100, 120, 
TRUE) <br/> sample(100, 5, 1234) # 1234 is the seed <br/> sample(100, 5, TRUE, 
1234)
 outer(vector1, vector2, "op") | Applies element wise binary operation "op" 
(for example: "&lt;", "==", "&gt;=", "*", "min") on the all combination of 
vector. <br/> Note: Using "*", we get outer product of two vectors. | Input: 
vectors of same size d, string <br/> Output: matrix of size d X d | A = 
matrix("1 4", rows = 2, cols = 1) <br/> B = matrix("3 6", rows = 1, cols = 2) 
<br/> C = outer(A, B, "&lt;") <br/> D = outer(A, B, "*") <br/> The output 
matrix C = [[1, 1], [0, 1]] <br/> The out [...]
-
+toOneHot(X, num_classes)| Converts a vector containing integers to a 
one-hot-encoded matrix | Input: vector with N integer entries between 1 and 
num_classes, number of columns (must be >= largest value in X)<br />Output: 
one-hot-encoded matrix with shape (N, num_classes) | X = round(rand(rows=10, 
cols=1, min=2, max=10)); <br />num_classes = 12; <br />Y = toOneHot(X, 
num_classes); 
 
 #### Alternative forms of table()
 
diff --git a/scripts/builtin/toOneHot.dml b/scripts/builtin/toOneHot.dml
new file mode 100644
index 0000000..8134f5c
--- /dev/null
+++ b/scripts/builtin/toOneHot.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# One-hot encodes a vector
+
+# INPUT PARAMETERS:
+# 
--------------------------------------------------------------------------------------------
+# NAME          TYPE    DEFAULT   MEANING
+# 
--------------------------------------------------------------------------------------------
+# X             matrix  ---       vector with N integer entries between 1 and 
numClasses
+# numClasses    int     ---       number of columns, must be >= largest value 
in X
+
+# Output: 
+# 
--------------------------------------------------------------------------------------------
+# NAME          TYPE     MEANING
+# 
-------------------------------------------------------------------------------------------
+# Y             matrix   one-hot-encoded matrix with shape (N, numClasses)
+# 
-------------------------------------------------------------------------------------------
+
+m_toOneHot = function(matrix[double] X, integer numClasses)
+        return (matrix[double] Y) {
+    if(numClasses < max(X))
+      stop("numClasses must be >= largest value in X to prevent cropping");
+    Y = table(seq(1, nrow(X)), X, nrow(X), numClasses);
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 6c53692..b09fc63 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -174,6 +174,7 @@ public enum Builtins {
        TAN("tan", false),
        TANH("tanh", false),
        TRACE("trace", false),
+       TO_ONE_HOT("toOneHot", true),
        TYPEOF("typeOf", false),
        VAR("var", false),
        XOR("xor", false),
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinToOneHotTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinToOneHotTest.java
new file mode 100644
index 0000000..852f420
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinToOneHotTest.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+import static org.junit.Assert.fail;
+
+public class BuiltinToOneHotTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "toOneHot";
+       private final static String TEST_DIR = "functions/builtin/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
BuiltinToOneHotTest.class.getSimpleName() + "/";
+
+       private final static double eps = 0;
+       private final static int rows = 10;
+       private final static int cols = 1;
+       private final static int numClasses = 10;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME,new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"}));
+       }
+
+       @Test
+       public void runSimpleTest() {
+               runToOneHotTest(false, false, LopProperties.ExecType.CP, false);
+       }
+
+       @Test
+       public void runFailingSimpleTest() {
+               runToOneHotTest(false, false, ExecType.CP, true);
+       }
+
+       private void runToOneHotTest(boolean scalar, boolean sparse, ExecType 
instType, boolean shouldFail) {
+               Types.ExecMode platformOld = setExecMode(instType);
+
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       //generate actual dataset
+                       double[][] A = TestUtils.round(getRandomMatrix(rows, 
cols, 1, numClasses, 1, 7));
+                       int max = -1;
+                       for(int i = 0; i < rows; i++)
+                               max = Math.max(max, (int) A[i][0]);
+                       writeInputMatrixWithMTD("A", A, false);
+
+                       // script fails if numClasses provided is smaller than 
maximum value in A
+                       int numClassesPassed = shouldFail ? max - 1 : max;
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-explain", "-args", 
input("A"),
+                               String.format("%d", numClassesPassed), 
output("B") };
+
+                       runTest(true, false, null, -1);
+
+                       if(!shouldFail) {
+                               HashMap<MatrixValue.CellIndex, Double> expected 
= computeExpectedResult(A);
+                               HashMap<MatrixValue.CellIndex, Double> result = 
readDMLMatrixFromHDFS("B");
+                               TestUtils.compareMatrices(result, expected, 
eps, "Stat-DML", "Stat-Java");
+                       }
+                       else {
+                               try {
+                                       readDMLMatrixFromHDFS("B");
+                                       fail("File should not have been 
written");
+                               } catch(AssertionError e) {
+                                       // exception expected
+                               }
+                       }
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+
+       private static HashMap<MatrixValue.CellIndex, Double> 
computeExpectedResult(double[][] a) {
+               HashMap<MatrixValue.CellIndex, Double> expected = new 
HashMap<>();
+               for(int i = 0; i < a.length; i++) {
+                       for(int j = 0; j < a[i].length; j++) {
+                               // indices start with 1 here
+                               expected.put(new MatrixValue.CellIndex(i + 1, 
(int) a[i][j]), 1.0);
+                       }
+               }
+               return expected;
+       }
+}
diff --git a/src/test/scripts/functions/builtin/toOneHot.dml 
b/src/test/scripts/functions/builtin/toOneHot.dml
new file mode 100644
index 0000000..fa509d4
--- /dev/null
+++ b/src/test/scripts/functions/builtin/toOneHot.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+num_classes = $2;
+Y = toOneHot(X, num_classes);
+write(Y, $3);

Reply via email to