This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 86fd7b3 [SYSTEMDS-394] New builtin function toOneHot (one hot
encoding)
86fd7b3 is described below
commit 86fd7b3d4aae5dbca8090e2638e0abc4da696655
Author: Patrick Deutschmann <[email protected]>
AuthorDate: Sat May 23 22:57:17 2020 +0200
[SYSTEMDS-394] New builtin function toOneHot (one hot encoding)
Adds a builtin function toOneHot, which transforms a column vector of
integers into a one-hot-encoded matrix (note: unlike transformencode, which
works over frames and reassigns the integer codes, toOneHot uses the given
integer values directly as column indexes)
Closes #916.
---
docs/Tasks.txt | 1 +
docs/dml-language-reference.md | 2 +-
scripts/builtin/toOneHot.dml | 43 ++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../functions/builtin/BuiltinToOneHotTest.java | 113 +++++++++++++++++++++
src/test/scripts/functions/builtin/toOneHot.dml | 25 +++++
6 files changed, 184 insertions(+), 1 deletion(-)
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 6539ff8..7b64145 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -302,6 +302,7 @@ SYSTEMDS-390 New Builtin Functions IV
* 391 New GLM builtin function (from algorithms) OK
* 392 Builtin function for missing value imputation via FDs OK
* 393 Builtin to find Connected Components of a graph OK
+ * 394 Builtin for one-hot encoding of matrix (not frame), see table OK
Others:
* Break append instruction to cbind and rbind
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 652a451..76f2656 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -699,7 +699,7 @@ cummin() | Column prefix-min (For row-prefix min, use
cummin(t(X)) | Input: matr
cummax() | Column prefix-max (For row-prefix min, use cummax(t(X)) | Input:
matrix <br/> Output: matrix of the same dimensions | A = matrix("3 4 1 6 5 2",
rows=3, cols=2) <br/> B = cummax(A) <br/> The output matrix B = [[3, 4], [3,
6], [5, 6]]
sample(range, size, replacement, seed) | Sample returns a column vector of
length size, containing uniform random numbers from [1, range] | Input: <br/>
range: integer <br/> size: integer <br/> replacement: boolean (Optional,
default: FALSE) <br/> seed: integer (Optional) <br/> Output: Matrix dimensions
are size x 1 | sample(100, 5) <br/> sample(100, 5, TRUE) <br/> sample(100, 120,
TRUE) <br/> sample(100, 5, 1234) # 1234 is the seed <br/> sample(100, 5, TRUE,
1234)
outer(vector1, vector2, "op") | Applies element wise binary operation "op"
(for example: "<", "==", ">=", "*", "min") on the all combination of
vector. <br/> Note: Using "*", we get outer product of two vectors. | Input:
vectors of same size d, string <br/> Output: matrix of size d X d | A =
matrix("1 4", rows = 2, cols = 1) <br/> B = matrix("3 6", rows = 1, cols = 2)
<br/> C = outer(A, B, "<") <br/> D = outer(A, B, "*") <br/> The output
matrix C = [[1, 1], [0, 1]] <br/> The out [...]
-
+toOneHot(X, numClasses) | Converts a column vector of integers into a one-hot-encoded matrix | Input: <br/> X: column vector with N integer entries between 1 and numClasses <br/> numClasses: integer, number of output columns (must be >= the largest value in X) <br/> Output: one-hot-encoded matrix of dimensions N x numClasses | X = round(rand(rows=10, cols=1, min=2, max=10)); <br/> numClasses = 12; <br/> Y = toOneHot(X, numClasses);
#### Alternative forms of table()
diff --git a/scripts/builtin/toOneHot.dml b/scripts/builtin/toOneHot.dml
new file mode 100644
index 0000000..8134f5c
--- /dev/null
+++ b/scripts/builtin/toOneHot.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# One-hot encodes a column vector of integer class labels: row i of the
+# output contains a single 1 in column X[i,1] and 0 everywhere else.
+
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME          TYPE     DEFAULT   MEANING
+# ---------------------------------------------------------------------------------------------
+# X             matrix   ---       column vector with N integer entries between 1 and numClasses
+# numClasses    int      ---       number of output columns, must be >= largest value in X
+
+# Output:
+# ---------------------------------------------------------------------------------------------
+# NAME          TYPE     MEANING
+# ---------------------------------------------------------------------------------------------
+# Y             matrix   one-hot-encoded matrix with shape (N, numClasses)
+# ---------------------------------------------------------------------------------------------
+
+m_toOneHot = function(matrix[double] X, integer numClasses)
+  return (matrix[double] Y) {
+  # table() pairs each row index with exactly one label, so X must be a single column
+  if( ncol(X) != 1 )
+    stop("toOneHot: X must be a column vector, but has " + ncol(X) + " columns");
+  # labels are used directly as 1-based column indexes of Y
+  if( min(X) < 1 )
+    stop("toOneHot: all values in X must be >= 1");
+  if( numClasses < max(X) )
+    stop("numClasses must be >= largest value in X to prevent cropping");
+  # table(i, X[i]) puts a 1 at (i, X[i]); the explicit output dimensions
+  # nrow(X) x numClasses keep trailing classes absent from X as zero columns
+  Y = table(seq(1, nrow(X)), X, nrow(X), numClasses);
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 6c53692..b09fc63 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -174,6 +174,7 @@ public enum Builtins {
TAN("tan", false),
TANH("tanh", false),
TRACE("trace", false),
+ TO_ONE_HOT("toOneHot", true),
TYPEOF("typeOf", false),
VAR("var", false),
XOR("xor", false),
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinToOneHotTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinToOneHotTest.java
new file mode 100644
index 0000000..852f420
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinToOneHotTest.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+import static org.junit.Assert.fail;
+
+public class BuiltinToOneHotTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "toOneHot";
+ private final static String TEST_DIR = "functions/builtin/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
BuiltinToOneHotTest.class.getSimpleName() + "/";
+
+ private final static double eps = 0;
+ private final static int rows = 10;
+ private final static int cols = 1;
+ private final static int numClasses = 10;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"}));
+ }
+
+ @Test
+ public void runSimpleTest() {
+ runToOneHotTest(false, false, LopProperties.ExecType.CP, false);
+ }
+
+ @Test
+ public void runFailingSimpleTest() {
+ runToOneHotTest(false, false, ExecType.CP, true);
+ }
+
+ private void runToOneHotTest(boolean scalar, boolean sparse, ExecType
instType, boolean shouldFail) {
+ Types.ExecMode platformOld = setExecMode(instType);
+
+ try
+ {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ //generate actual dataset
+ double[][] A = TestUtils.round(getRandomMatrix(rows,
cols, 1, numClasses, 1, 7));
+ int max = -1;
+ for(int i = 0; i < rows; i++)
+ max = Math.max(max, (int) A[i][0]);
+ writeInputMatrixWithMTD("A", A, false);
+
+ // script fails if numClasses provided is smaller than
maximum value in A
+ int numClassesPassed = shouldFail ? max - 1 : max;
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-explain", "-args",
input("A"),
+ String.format("%d", numClassesPassed),
output("B") };
+
+ runTest(true, false, null, -1);
+
+ if(!shouldFail) {
+ HashMap<MatrixValue.CellIndex, Double> expected
= computeExpectedResult(A);
+ HashMap<MatrixValue.CellIndex, Double> result =
readDMLMatrixFromHDFS("B");
+ TestUtils.compareMatrices(result, expected,
eps, "Stat-DML", "Stat-Java");
+ }
+ else {
+ try {
+ readDMLMatrixFromHDFS("B");
+ fail("File should not have been
written");
+ } catch(AssertionError e) {
+ // exception expected
+ }
+ }
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+
+ private static HashMap<MatrixValue.CellIndex, Double>
computeExpectedResult(double[][] a) {
+ HashMap<MatrixValue.CellIndex, Double> expected = new
HashMap<>();
+ for(int i = 0; i < a.length; i++) {
+ for(int j = 0; j < a[i].length; j++) {
+ // indices start with 1 here
+ expected.put(new MatrixValue.CellIndex(i + 1,
(int) a[i][j]), 1.0);
+ }
+ }
+ return expected;
+ }
+}
diff --git a/src/test/scripts/functions/builtin/toOneHot.dml
b/src/test/scripts/functions/builtin/toOneHot.dml
new file mode 100644
index 0000000..fa509d4
--- /dev/null
+++ b/src/test/scripts/functions/builtin/toOneHot.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test driver: reads the input vector from $1, one-hot encodes it into $2
+# columns via the toOneHot builtin, and writes the result to $3.
+X = read($1);
+Y = toOneHot(X, $2);
+write(Y, $3);